diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 80283643861ac..e077eace74375 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -71,6 +71,7 @@ exportMethods( "unpersist", "value", "values", + "zipPartitions", "zipRDD", "zipWithIndex", "zipWithUniqueId" diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 1662d6bb3b1ac..a3a0421a0746d 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -66,6 +66,11 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode, .Object }) +setMethod("show", "RDD", + function(.Object) { + cat(paste(callJMethod(.Object@jrdd, "toString"), "\n", sep="")) + }) + setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) { .Object@env <- new.env() .Object@env$isCached <- FALSE @@ -1590,3 +1595,49 @@ setMethod("intersection", keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction)) }) + +#' Zips an RDD's partitions with one (or more) RDD(s). +#' Same as zipPartitions in Spark. +#' +#' @param ... RDDs to be zipped. +#' @param func A function to transform zipped partitions. +#' @return A new RDD by applying a function to the zipped partitions. +#' Assumes that all the RDDs have the *same number of partitions*, but +#' does *not* require them to have the same number of elements in each partition. +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 +#' rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 +#' rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 +#' collect(zipPartitions(rdd1, rdd2, rdd3, +#' func = function(x, y, z) { list(list(x, y, z))} )) +#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) +#'} +#' @rdname zipRDD +#' @aliases zipPartitions,RDD +setMethod("zipPartitions", + "RDD", + function(..., func) { + rrdds <- list(...) + if (length(rrdds) == 1) { + return(rrdds[[1]]) + } + nPart <- sapply(rrdds, numPartitions) + if (length(unique(nPart)) != 1) { + stop("Can only zipPartitions RDDs which have the same number of partitions.") + } + + rrdds <- lapply(rrdds, function(rdd) { + mapPartitionsWithIndex(rdd, function(partIndex, part) { + print(length(part)) + list(list(partIndex, part)) + }) + }) + union.rdd <- Reduce(unionRDD, rrdds) + zipped.rdd <- values(groupByKey(union.rdd, numPartitions = nPart[1])) + res <- mapPartitions(zipped.rdd, function(plist) { + do.call(func, plist[[1]]) + }) + res + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 34dbe84051c50..e88729387ef95 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -217,6 +217,11 @@ setGeneric("unpersist", function(x, ...) 
{ standardGeneric("unpersist") }) #' @export setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") }) +#' @rdname zipRDD +#' @export +setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") }, + signature = "...") + #' @rdname zipWithIndex #' @seealso zipWithUniqueId #' @export diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index c15553ba28517..6785a7bdae8cb 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -66,3 +66,36 @@ test_that("cogroup on two RDDs", { expect_equal(sortKeyValueList(actual), sortKeyValueList(expected)) }) + +test_that("zipPartitions() on RDDs", { + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 + rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 + rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 + actual <- collect(zipPartitions(rdd1, rdd2, rdd3, + func = function(x, y, z) { list(list(x, y, z))} )) + expect_equal(actual, + list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))) + + mockFile = c("Spark is pretty.", "Spark is awesome.") + fileName <- tempfile(pattern="spark-test", fileext=".tmp") + writeLines(mockFile, fileName) + + rdd <- textFile(sc, fileName, 1) + actual <- collect(zipPartitions(rdd, rdd, + func = function(x, y) { list(paste(x, y, sep = "\n")) })) + expected <- list(paste(mockFile, mockFile, sep = "\n")) + expect_equal(actual, expected) + + rdd1 <- parallelize(sc, 0:1, 1) + actual <- collect(zipPartitions(rdd1, rdd, + func = function(x, y) { list(x + nchar(y)) })) + expected <- list(0:1 + nchar(mockFile)) + expect_equal(actual, expected) + + rdd <- map(rdd, function(x) { x }) + actual <- collect(zipPartitions(rdd, rdd1, + func = function(x, y) { list(y + nchar(x)) })) + expect_equal(actual, expected) + + unlink(fileName) +}) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index d55af93e3e50a..03207353c31c6 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -759,6 +759,11 @@ test_that("collectAsMap() on a pairwise RDD", { expect_equal(vals, list(`1` = "a", `2` = "b")) }) +test_that("show()", { + rdd <- parallelize(sc, list(1:10)) + expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") +}) + test_that("sampleByKey() on pairwise RDDs", { rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) diff --git a/assembly/pom.xml b/assembly/pom.xml index 20593e710dedb..2b4d0a990bf22 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -194,7 +194,6 @@ org.apache.maven.plugins maven-assembly-plugin - 2.4 dist diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 3d068dd3a2739..db09fa27e51a6 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -61,7 +61,10 @@ if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. 
-for /f "tokens=*" %%i in ('cmd /C ""%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %*"') do ( +set LAUNCHER_OUTPUT=%temp%\spark-class-launcher-output-%RANDOM%.txt +"%RUNNER%" -cp %LAUNCH_CLASSPATH% org.apache.spark.launcher.Main %* > %LAUNCHER_OUTPUT% +for /f "tokens=*" %%i in (%LAUNCHER_OUTPUT%) do ( set SPARK_CMD=%%i ) +del %LAUNCHER_OUTPUT% %SPARK_CMD% diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 67f81d33361e1..43c4288912b18 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -3,7 +3,7 @@ # This file is sourced when running various Spark programs. # Copy it as spark-env.sh and edit that to configure Spark for your site. -# Options read when launching programs locally with +# Options read when launching programs locally with # ./bin/run-example or ./bin/spark-submit # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node @@ -39,6 +39,7 @@ # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") +# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers diff --git a/core/pom.xml b/core/pom.xml index 5e89d548cd47f..2dfb00d7ecf26 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,11 @@ spark-network-shuffle_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + net.java.dev.jets3t jets3t @@ -478,7 +483,6 @@ org.codehaus.mojo exec-maven-plugin - 1.3.2 sparkr-pkg diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 68d05d5b02537..f2b024ff6cb67 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -76,13 +76,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) private var timeoutCheckingTask: ScheduledFuture[_] = null - private val timeoutCheckingThread = - ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread") + // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not + // block the thread for a long time. 
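As a rough, self-contained sketch of the pattern this hunk adopts (names are assumed; plain java.util.concurrent is used instead of Spark's ThreadUtils): the RPC handler does only fast bookkeeping and pushes the slow scheduler call and the reply onto a single-threaded event loop, so the RPC thread is never blocked.

```scala
import java.util.concurrent.Executors

// Illustrative only: HeartbeatLoopSketch, onHeartbeat and process are made-up names.
object HeartbeatLoopSketch {
  private val eventLoop = Executors.newSingleThreadScheduledExecutor()

  def onHeartbeat(executorId: String)(reply: Boolean => Unit): Unit = {
    val receivedAt = System.currentTimeMillis() // fast: record last-seen time on the caller's thread
    eventLoop.submit(new Runnable {
      // slow work and the reply happen on the event loop, not on the RPC thread
      override def run(): Unit = reply(process(executorId, receivedAt))
    })
  }

  private def process(id: String, ts: Long): Boolean = true // stand-in for the scheduler call

  def shutdown(): Unit = eventLoop.shutdownNow()
}
```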
+ private val eventLoopThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread") private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread") override def onStart(): Unit = { - timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable { + timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { Option(self).foreach(_.send(ExpireDeadHosts)) } @@ -99,11 +101,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) => if (scheduler != null) { - val unknownExecutor = !scheduler.executorHeartbeatReceived( - executorId, taskMetrics, blockManagerId) - val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) executorLastSeen(executorId) = System.currentTimeMillis() - context.reply(response) + eventLoopThread.submit(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val unknownExecutor = !scheduler.executorHeartbeatReceived( + executorId, taskMetrics, blockManagerId) + val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor) + context.reply(response) + } + }) } else { // Because Executor will sleep several seconds before sending the first "Heartbeat", this // case rarely happens. However, if it really happens, log it and ask the executor to @@ -125,7 +131,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (sc.supportDynamicAllocation) { // Asynchronously kill the executor to avoid blocking the current thread killExecutorThread.submit(new Runnable { - override def run(): Unit = sc.killExecutor(executorId) + override def run(): Unit = Utils.tryLogNonFatalError { + sc.killExecutor(executorId) + } }) } executorLastSeen.remove(executorId) @@ -137,7 +145,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext) if (timeoutCheckingTask != null) { timeoutCheckingTask.cancel(true) } - timeoutCheckingThread.shutdownNow() + eventLoopThread.shutdownNow() killExecutorThread.shutdownNow() } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index d65c94e410662..16072283edbe9 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -106,7 +106,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker[T: ClassTag](message: Any): T = { try { - trackerEndpoint.askWithReply[T](message) + trackerEndpoint.askWithRetry[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c1996e08756a6..a8fc90ad2050e 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -211,7 +211,74 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { Utils.timeStringAsMs(get(key, defaultValue)) } + /** + * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then bytes are assumed. 
+ * @throws NoSuchElementException + */ + def getSizeAsBytes(key: String): Long = { + Utils.byteStringAsBytes(get(key)) + } + + /** + * Get a size parameter as bytes, falling back to a default if not set. If no + * suffix is provided then bytes are assumed. + */ + def getSizeAsBytes(key: String, defaultValue: String): Long = { + Utils.byteStringAsBytes(get(key, defaultValue)) + } + + /** + * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Kibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsKb(key: String): Long = { + Utils.byteStringAsKb(get(key)) + } + + /** + * Get a size parameter as Kibibytes, falling back to a default if not set. If no + * suffix is provided then Kibibytes are assumed. + */ + def getSizeAsKb(key: String, defaultValue: String): Long = { + Utils.byteStringAsKb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Mebibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsMb(key: String): Long = { + Utils.byteStringAsMb(get(key)) + } + + /** + * Get a size parameter as Mebibytes, falling back to a default if not set. If no + * suffix is provided then Mebibytes are assumed. + */ + def getSizeAsMb(key: String, defaultValue: String): Long = { + Utils.byteStringAsMb(get(key, defaultValue)) + } + + /** + * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no + * suffix is provided then Gibibytes are assumed. + * @throws NoSuchElementException + */ + def getSizeAsGb(key: String): Long = { + Utils.byteStringAsGb(get(key)) + } + /** + * Get a size parameter as Gibibytes, falling back to a default if not set. If no + * suffix is provided then Gibibytes are assumed. + */ + def getSizeAsGb(key: String, defaultValue: String): Long = { + Utils.byteStringAsGb(get(key, defaultValue)) + } + /** Get a parameter as an Option */ def getOption(key: String): Option[String] = { Option(settings.get(key)).orElse(getDeprecatedConfig(key, this)) @@ -407,7 +474,13 @@ private[spark] object SparkConf extends Logging { "The spark.cache.class property is no longer being used! Specify storage levels using " + "the RDD.persist() method instead."), DeprecatedConfig("spark.yarn.user.classpath.first", "1.3", - "Please use spark.{driver,executor}.userClassPathFirst instead.")) + "Please use spark.{driver,executor}.userClassPathFirst instead."), + DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", + "Please use spark.kryoserializer.buffer instead. The default value for " + + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + + "are no longer accepted. To specify the equivalent now, one may use '64k'.") + ) + Map(configs.map { cfg => (cfg.key -> cfg) }:_*) } @@ -432,6 +505,21 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3", // Translate old value to a duration, with 10s wait time per try. 
translation = s => s"${s.toLong * 10}s")), + "spark.reducer.maxSizeInFlight" -> Seq( + AlternateConfig("spark.reducer.maxMbInFlight", "1.4")), + "spark.kryoserializer.buffer" -> + Seq(AlternateConfig("spark.kryoserializer.buffer.mb", "1.4", + translation = s => s"${s.toDouble * 1000}k")), + "spark.kryoserializer.buffer.max" -> Seq( + AlternateConfig("spark.kryoserializer.buffer.max.mb", "1.4")), + "spark.shuffle.file.buffer" -> Seq( + AlternateConfig("spark.shuffle.file.buffer.kb", "1.4")), + "spark.executor.logs.rolling.maxSize" -> Seq( + AlternateConfig("spark.executor.logs.rolling.size.maxBytes", "1.4")), + "spark.io.compression.snappy.blockSize" -> Seq( + AlternateConfig("spark.io.compression.snappy.block.size", "1.4")), + "spark.io.compression.lz4.blockSize" -> Seq( + AlternateConfig("spark.io.compression.lz4.block.size", "1.4")), "spark.rpc.numRetries" -> Seq( AlternateConfig("spark.akka.num.retries", "1.4")), "spark.rpc.retry.wait" -> Seq( diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 86269eac52db0..5ae8fb81de809 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -223,6 +223,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private var _listenerBusStarted: Boolean = false private var _jars: Seq[String] = _ private var _files: Seq[String] = _ + private var _shutdownHookRef: AnyRef = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -517,6 +518,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _taskScheduler.postStartHook() _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) + + // Make sure the context is stopped if the user forgets about it. This avoids leaving + // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM + // is killed, though. + _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () => + logInfo("Invoking stop() from shutdown hook") + stop() + } } catch { case NonFatal(e) => logError("Error initializing SparkContext.", e) @@ -546,7 +555,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli SparkEnv.executorActorSystemName, RpcAddress(host, port), ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME) - Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) + Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -1055,7 +1064,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** Build the union of a list of RDDs. 
*/ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = { val partitioners = rdds.flatMap(_.partitioner).toSet - if (partitioners.size == 1) { + if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) { new PartitionerAwareUnionRDD(this, rdds) } else { new UnionRDD(this, rdds) @@ -1387,6 +1396,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register an RDD to be persisted in memory and/or disk storage */ private[spark] def persistRDD(rdd: RDD[_]) { + _executorAllocationManager.foreach { _ => + logWarning( + s"Dynamic allocation currently does not support cached RDDs. Cached data for RDD " + + s"${rdd.id} will be lost when executors are removed.") + } persistentRdds(rdd.id) = rdd } @@ -1481,6 +1495,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli logInfo("SparkContext already stopped.") return } + if (_shutdownHookRef != null) { + Utils.removeShutdownHook(_shutdownHookRef) + } postApplicationEnd() _ui.foreach(_.stop()) @@ -1891,7 +1908,7 @@ object SparkContext extends Logging { * * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK. */ - private val activeContext: AtomicReference[SparkContext] = + private val activeContext: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null) /** @@ -1944,11 +1961,11 @@ object SparkContext extends Logging { } /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, - * this is useful when applications may wish to share a SparkContext. + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, + * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * Note: This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(config: SparkConf): SparkContext = { @@ -1961,17 +1978,17 @@ object SparkContext extends Logging { activeContext.get() } } - + /** - * This function may be used to get or instantiate a SparkContext and register it as a - * singleton object. Because we can only have one active SparkContext per JVM, + * This function may be used to get or instantiate a SparkContext and register it as a + * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. - * + * * This method allows not passing a SparkConf (useful if just retrieving). - * - * Note: This function cannot be used to create multiple SparkContext instances - * even if multiple contexts are allowed. - */ + * + * Note: This function cannot be used to create multiple SparkContext instances + * even if multiple contexts are allowed. 
+ */ def getOrCreate(): SparkContext = { getOrCreate(new SparkConf()) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 959aefabd8de4..0c4d28f786edd 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} import org.apache.spark.util.{RpcUtils, Utils} /** @@ -69,6 +70,7 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -382,6 +384,15 @@ object SparkEnv extends Logging { new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) + val executorMemoryManager: ExecutorMemoryManager = { + val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { + MemoryAllocator.UNSAFE + } else { + MemoryAllocator.HEAP + } + new ExecutorMemoryManager(allocator) + } + val envInstance = new SparkEnv( executorId, rpcEnv, @@ -398,6 +409,7 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + executorMemoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 7d7fe1a446313..d09e17dea0911 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable { /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics + + /** + * Returns the manager for this task's managed memory. 
+ */ + private[spark] def taskMemoryManager(): TaskMemoryManager } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 337c8e4ebebcd..b4d572cb52313 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import scala.collection.mutable.ArrayBuffer @@ -27,6 +28,7 @@ private[spark] class TaskContextImpl( val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 23b02e60338fb..a0c9b5e63c744 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -74,7 +74,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } else { None } - blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + // Note: use getSizeAsKb (not bytes) to maintain compatiblity if no units are provided + blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 } setConf(SparkEnv.get.conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala similarity index 59% rename from core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala rename to core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index b9798963bab0a..cd16f992a3c0a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -15,7 +15,9 @@ * limitations under the License. */ -package org.apache.spark.deploy.worker +package org.apache.spark.deploy + +import java.util.concurrent.CountDownLatch import org.apache.spark.{Logging, SparkConf, SecurityManager} import org.apache.spark.network.TransportContext @@ -23,6 +25,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslRpcHandler import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler +import org.apache.spark.util.Utils /** * Provides a server from which Executors can read shuffle files (rather than reading directly from @@ -31,8 +34,8 @@ import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler * * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. 
*/ -private[worker] -class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) +private[deploy] +class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) extends Logging { private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) @@ -51,16 +54,58 @@ class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: Secu /** Starts the external shuffle service if the user has configured us to. */ def startIfEnabled() { if (enabled) { - require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") - server = transportContext.createServer(port) + start() } } + /** Start the external shuffle service */ + def start() { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + def stop() { - if (enabled && server != null) { + if (server != null) { server.close() server = null } } } + +/** + * A main class for running the external shuffle service. + */ +object ExternalShuffleService extends Logging { + @volatile + private var server: ExternalShuffleService = _ + + private val barrier = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + val sparkConf = new SparkConf + Utils.loadDefaultSparkProperties(sparkConf) + val securityManager = new SecurityManager(sparkConf) + + // we override this value since this service is started from the command line + // and we assume the user really wants it to be running + sparkConf.set("spark.shuffle.service.enabled", "true") + server = new ExternalShuffleService(sparkConf, securityManager) + server.start() + + installShutdownHook() + + // keep running until the process is terminated + barrier.await() + } + + private def installShutdownHook(): Unit = { + Runtime.getRuntime.addShutdownHook(new Thread("External Shuffle Service shutdown thread") { + override def run() { + logInfo("Shutting down shuffle service.") + server.stop() + barrier.countDown() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index a7c89276a045e..c048b78910f38 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -32,7 +32,7 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.{Logging, SparkConf, SparkContext} -import org.apache.spark.deploy.master.{RecoveryState, SparkCuratorUtil} +import org.apache.spark.deploy.master.RecoveryState import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala similarity index 89% rename from core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala rename to core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala index 5b22481ea8c5f..b8d3993540220 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkCuratorUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkCuratorUtil.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.deploy.master +package org.apache.spark.deploy import scala.collection.JavaConversions._ @@ -25,15 +25,17 @@ import org.apache.zookeeper.KeeperException import org.apache.spark.{Logging, SparkConf} -private[deploy] object SparkCuratorUtil extends Logging { +private[spark] object SparkCuratorUtil extends Logging { private val ZK_CONNECTION_TIMEOUT_MILLIS = 15000 private val ZK_SESSION_TIMEOUT_MILLIS = 60000 private val RETRY_WAIT_MILLIS = 5000 private val MAX_RECONNECT_ATTEMPTS = 3 - def newClient(conf: SparkConf): CuratorFramework = { - val ZK_URL = conf.get("spark.deploy.zookeeper.url") + def newClient( + conf: SparkConf, + zkUrlConf: String = "spark.deploy.zookeeper.url"): CuratorFramework = { + val ZK_URL = conf.get(zkUrlConf) val zk = CuratorFrameworkFactory.newClient(ZK_URL, ZK_SESSION_TIMEOUT_MILLIS, ZK_CONNECTION_TIMEOUT_MILLIS, new ExponentialBackoffRetry(RETRY_WAIT_MILLIS, MAX_RECONNECT_ATTEMPTS)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 296a0764b8baf..b8ae4af18d1d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -36,11 +36,11 @@ import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver} - import org.apache.spark.SPARK_VERSION import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} + /** * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone cluster mode. @@ -114,18 +114,20 @@ object SparkSubmit { } } - /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */ + /** + * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. + */ private def kill(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .killSubmission(args.master, args.submissionToKill) } /** * Request the status of an existing submission using the REST protocol. - * Standalone cluster mode only. + * Standalone and Mesos cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient() + new RestSubmissionClient() .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor) } @@ -252,6 +254,7 @@ object SparkSubmit { } val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER // Resolve maven dependencies if there are any and add classpath to jars. 
Add them to py-files // too for packages that include Python code @@ -294,8 +297,9 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) => - printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.") + case (MESOS, CLUSTER) if args.isPython => + printErrorAndExit("Cluster deploy mode is currently not supported for python " + + "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -377,15 +381,6 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Standalone cluster only - // Do not set CL arguments here because there are multiple possibilities for the main class - OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"), - OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER, - sysProp = "spark.driver.supervise"), - // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, CLIENT, sysProp = "spark.executor.instances"), @@ -413,7 +408,15 @@ object SparkSubmit { OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files") + sysProp = "spark.files"), + OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.memory"), + OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.cores"), + OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, + sysProp = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy") ) // In client mode, launch the application main class directly @@ -452,7 +455,7 @@ object SparkSubmit { // All Spark parameters are expected to be passed to the client through system properties. 
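The OptionAssigner changes above fold the standalone-only cluster entries into shared STANDALONE | MESOS ones. A simplified, self-contained sketch of the idea (not Spark's actual private class; all names here are illustrative) that maps CLI values to system properties by cluster manager and deploy mode:

```scala
object OptionAssignerSketch {
  // Each entry declares where an option applies and which system property it populates.
  final case class OptionAssigner(value: String, managers: Set[String], modes: Set[String], sysProp: String)

  def toSysProps(opts: Seq[OptionAssigner], manager: String, mode: String): Map[String, String] =
    opts.collect {
      case o if o.value != null && o.managers(manager) && o.modes(mode) => o.sysProp -> o.value
    }.toMap

  def main(args: Array[String]): Unit = {
    val opts = Seq(
      OptionAssigner("2g", Set("standalone", "mesos"), Set("cluster"), "spark.driver.memory"),
      OptionAssigner("true", Set("standalone", "mesos"), Set("cluster"), "spark.driver.supervise"),
      OptionAssigner("4", Set("yarn"), Set("client"), "spark.executor.instances"))
    // A Mesos cluster-mode submission picks up only the shared cluster entries:
    println(toSysProps(opts, "mesos", "cluster"))
    // Map(spark.driver.memory -> 2g, spark.driver.supervise -> true)
  }
}
```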
if (args.isStandaloneCluster) { if (args.useRest) { - childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient" + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" childArgs += (args.primaryResource, args.mainClass) } else { // In legacy standalone cluster mode, use Client as a wrapper around the user class @@ -496,6 +499,15 @@ object SparkSubmit { } } + if (isMesosCluster) { + assert(args.useRest, "Mesos cluster mode is only supported through the REST submission API") + childMainClass = "org.apache.spark.deploy.rest.RestSubmissionClient" + childArgs += (args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) @@ -722,13 +734,31 @@ private[deploy] object SparkSubmitUtils { /** * Extracts maven coordinates from a comma-delimited string * @param remoteRepos Comma-delimited string of remote repositories + * @param ivySettings The Ivy settings for this session * @return A ChainResolver used by Ivy to search for and resolve dependencies. */ - def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = { + def createRepoResolvers(remoteRepos: Option[String], ivySettings: IvySettings): ChainResolver = { // We need a chain resolver if we want to check multiple repositories val cr = new ChainResolver cr.setName("list") + val localM2 = new IBiblioResolver + localM2.setM2compatible(true) + val m2Path = ".m2" + File.separator + "repository" + File.separator + localM2.setRoot(new File(System.getProperty("user.home"), m2Path).toURI.toString) + localM2.setUsepoms(true) + localM2.setName("local-m2-cache") + cr.add(localM2) + + val localIvy = new IBiblioResolver + localIvy.setRoot(new File(ivySettings.getDefaultIvyUserDir, + "local" + File.separator).toURI.toString) + val ivyPattern = Seq("[organisation]", "[module]", "[revision]", "[type]s", + "[artifact](-[classifier]).[ext]").mkString(File.separator) + localIvy.setPattern(ivyPattern) + localIvy.setName("local-ivy-cache") + cr.add(localIvy) + // the biblio resolver resolves POM declared dependencies val br: IBiblioResolver = new IBiblioResolver br.setM2compatible(true) @@ -761,8 +791,7 @@ private[deploy] object SparkSubmitUtils { /** * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath - * (will append to jars in SparkSubmit). The name of the jar is given - * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well. + * (will append to jars in SparkSubmit). 
* @param artifacts Sequence of dependencies that were resolved and retrieved * @param cacheDirectory directory where jars are cached * @return a comma-delimited list of paths for the dependencies @@ -771,10 +800,9 @@ private[deploy] object SparkSubmitUtils { artifacts: Array[AnyRef], cacheDirectory: File): String = { artifacts.map { artifactInfo => - val artifactString = artifactInfo.toString - val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1) + val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId cacheDirectory.getAbsolutePath + File.separator + - jarName.substring(0, jarName.lastIndexOf(".jar") + 4) + s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}.jar" }.mkString(",") } @@ -856,6 +884,7 @@ private[deploy] object SparkSubmitUtils { if (alternateIvyCache.trim.isEmpty) { new File(ivySettings.getDefaultIvyUserDir, "jars") } else { + ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) new File(alternateIvyCache, "jars") } @@ -865,7 +894,7 @@ private[deploy] object SparkSubmitUtils { // create a pattern matcher ivySettings.addMatcher(new GlobPatternMatcher) // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos) + val repoResolver = createRepoResolvers(remoteRepos, ivySettings) ivySettings.addResolver(repoResolver) ivySettings.setDefaultResolver(repoResolver.getName) @@ -899,7 +928,8 @@ private[deploy] object SparkSubmitUtils { } // retrieve all resolved dependencies ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId, - packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]", + packagesDirectory.getAbsolutePath + File.separator + + "[organization]_[artifact]-[revision].[ext]", retrieveOptions.setConfs(Array(ivyConfName))) System.setOut(sysOut) resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index c896842943f2b..c621b8fc86f94 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -241,8 +241,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateKillArguments(): Unit = { - if (!master.startsWith("spark://")) { - SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!") + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + SparkSubmit.printErrorAndExit( + "Killing submissions is only supported in standalone or Mesos mode!") } if (submissionToKill == null) { SparkSubmit.printErrorAndExit("Please specify a submission to kill.") @@ -250,9 +251,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateStatusRequestArguments(): Unit = { - if (!master.startsWith("spark://")) { + if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { SparkSubmit.printErrorAndExit( - "Requesting submission statuses is only supported in standalone mode!") + "Requesting submission statuses is only supported in standalone or Mesos mode!") } if (submissionToRequestStatusFor == null) { SparkSubmit.printErrorAndExit("Please specify a submission to request status for.") @@ -485,6 +486,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S 
| | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). + | + | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. | --kill SUBMISSION_ID If given, kills the driver specified. | --status SUBMISSION_ID If given, requests the status of the driver specified. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index a94ebf6e53750..fb2cbbcccc54b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -333,8 +333,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } try { val appListener = new ApplicationEventListener + val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) - bus.replay(logInput, logPath.toString) + bus.replay(logInput, logPath.toString, !appCompleted) new FsApplicationHistoryInfo( logPath.getName(), appListener.appId.getOrElse(logPath.getName()), @@ -343,7 +344,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis appListener.endTime.getOrElse(-1L), getModificationTime(eventLog).get, appListener.sparkUser.getOrElse(NOT_STARTED), - isApplicationCompleted(eventLog)) + appCompleted) } finally { logInput.close() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index ff2eed6dee70a..1c21c179562ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -130,7 +130,7 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, self, masterUrl, conf)) + Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) } else { None } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 4823fd7cac0cb..52758d6a7c4be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -23,6 +23,7 @@ import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index a285783f72000..80db6d474b5c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -26,6 +26,7 @@ import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} 
+import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala new file mode 100644 index 0000000000000..5d4e5b899dfdc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import java.util.concurrent.CountDownLatch + +import org.apache.spark.deploy.mesos.ui.MesosClusterUI +import org.apache.spark.deploy.rest.mesos.MesosRestServer +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.util.SignalLogger +import org.apache.spark.{Logging, SecurityManager, SparkConf} + +/* + * A dispatcher that is responsible for managing and launching drivers, and is intended to be + * used for Mesos cluster mode. The dispatcher is a long-running process started by the user in + * the cluster independently of Spark applications. + * It contains a [[MesosRestServer]] that listens for requests to submit drivers and a + * [[MesosClusterScheduler]] that processes these requests by negotiating with the Mesos master + * for resources. + * + * A typical new driver lifecycle is the following: + * - Driver submitted via spark-submit talking to the [[MesosRestServer]] + * - [[MesosRestServer]] queues the driver request to [[MesosClusterScheduler]] + * - [[MesosClusterScheduler]] gets resource offers and launches the drivers that are in queue + * + * This dispatcher supports both Mesos fine-grain or coarse-grain mode as the mode is configurable + * per driver launched. + * This class is needed since Mesos doesn't manage frameworks, so the dispatcher acts as + * a daemon to launch drivers as Mesos frameworks upon request. The dispatcher is also started and + * stopped by sbin/start-mesos-dispatcher and sbin/stop-mesos-dispatcher respectively. 
+ */ +private[mesos] class MesosClusterDispatcher( + args: MesosClusterDispatcherArguments, + conf: SparkConf) + extends Logging { + + private val publicAddress = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(args.host) + private val recoveryMode = conf.get("spark.mesos.deploy.recoveryMode", "NONE").toUpperCase() + logInfo("Recovery mode in Mesos dispatcher set to: " + recoveryMode) + + private val engineFactory = recoveryMode match { + case "NONE" => new BlackHoleMesosClusterPersistenceEngineFactory + case "ZOOKEEPER" => new ZookeeperMesosClusterPersistenceEngineFactory(conf) + case _ => throw new IllegalArgumentException("Unsupported recovery mode: " + recoveryMode) + } + + private val scheduler = new MesosClusterScheduler(engineFactory, conf) + + private val server = new MesosRestServer(args.host, args.port, conf, scheduler) + private val webUi = new MesosClusterUI( + new SecurityManager(conf), + args.webUiPort, + conf, + publicAddress, + scheduler) + + private val shutdownLatch = new CountDownLatch(1) + + def start(): Unit = { + webUi.bind() + scheduler.frameworkUrl = webUi.activeWebUiUrl + scheduler.start() + server.start() + } + + def awaitShutdown(): Unit = { + shutdownLatch.await() + } + + def stop(): Unit = { + webUi.stop() + server.stop() + scheduler.stop() + shutdownLatch.countDown() + } +} + +private[mesos] object MesosClusterDispatcher extends Logging { + def main(args: Array[String]) { + SignalLogger.register(log) + val conf = new SparkConf + val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) + conf.setMaster(dispatcherArgs.masterUrl) + conf.setAppName(dispatcherArgs.name) + dispatcherArgs.zookeeperUrl.foreach { z => + conf.set("spark.mesos.deploy.recoveryMode", "ZOOKEEPER") + conf.set("spark.mesos.deploy.zookeeper.url", z) + } + val dispatcher = new MesosClusterDispatcher(dispatcherArgs, conf) + dispatcher.start() + val shutdownHook = new Thread() { + override def run() { + logInfo("Shutdown hook is shutting down dispatcher") + dispatcher.stop() + dispatcher.awaitShutdown() + } + } + Runtime.getRuntime.addShutdownHook(shutdownHook) + dispatcher.awaitShutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala new file mode 100644 index 0000000000000..894cb78d8591a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkConf +import org.apache.spark.util.{IntParam, Utils} + + +private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { + var host = Utils.localHostName() + var port = 7077 + var name = "Spark Cluster" + var webUiPort = 8081 + var masterUrl: String = _ + var zookeeperUrl: Option[String] = None + var propertiesFile: String = _ + + parse(args.toList) + + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + private def parse(args: List[String]): Unit = args match { + case ("--host" | "-h") :: value :: tail => + Utils.checkHost(value, "Please use hostname " + value) + host = value + parse(tail) + + case ("--port" | "-p") :: IntParam(value) :: tail => + port = value + parse(tail) + + case ("--webui-port" | "-p") :: IntParam(value) :: tail => + webUiPort = value + parse(tail) + + case ("--zk" | "-z") :: value :: tail => + zookeeperUrl = Some(value) + parse(tail) + + case ("--master" | "-m") :: value :: tail => + if (!value.startsWith("mesos://")) { + System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + System.exit(1) + } + masterUrl = value.stripPrefix("mesos://") + parse(tail) + + case ("--name") :: value :: tail => + name = value + parse(tail) + + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => + printUsageAndExit(0) + + case Nil => { + if (masterUrl == null) { + System.err.println("--master is required") + printUsageAndExit(1) + } + } + + case _ => + printUsageAndExit(1) + } + + private def printUsageAndExit(exitCode: Int): Unit = { + System.err.println( + "Usage: MesosClusterDispatcher [options]\n" + + "\n" + + "Options:\n" + + " -h HOST, --host HOST Hostname to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + + " --name NAME Framework name to show in Mesos UI\n" + + " -m --master MASTER URI for connecting to Mesos master\n" + + " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + + " Zookeeper for persistence\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala new file mode 100644 index 0000000000000..1948226800afe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosDriverDescription.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos + +import java.util.Date + +import org.apache.spark.deploy.Command +import org.apache.spark.scheduler.cluster.mesos.MesosClusterRetryState + +/** + * Describes a Spark driver that is submitted from the + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]], to be launched by + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * @param jarUrl URL to the application jar + * @param mem Amount of memory for the driver + * @param cores Number of cores for the driver + * @param supervise Supervise the driver for long running app + * @param command The command to launch the driver. + * @param schedulerProperties Extra properties to pass the Mesos scheduler + */ +private[spark] class MesosDriverDescription( + val name: String, + val jarUrl: String, + val mem: Int, + val cores: Double, + val supervise: Boolean, + val command: Command, + val schedulerProperties: Map[String, String], + val submissionId: String, + val submissionDate: Date, + val retryState: Option[MesosClusterRetryState] = None) + extends Serializable { + + def copy( + name: String = name, + jarUrl: String = jarUrl, + mem: Int = mem, + cores: Double = cores, + supervise: Boolean = supervise, + command: Command = command, + schedulerProperties: Map[String, String] = schedulerProperties, + submissionId: String = submissionId, + submissionDate: Date = submissionDate, + retryState: Option[MesosClusterRetryState] = retryState): MesosDriverDescription = { + new MesosDriverDescription(name, jarUrl, mem, cores, supervise, command, schedulerProperties, + submissionId, submissionDate, retryState) + } + + override def toString: String = s"MesosDriverDescription (${command.mainClass})" +} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala new file mode 100644 index 0000000000000..7b2005e0f1237 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.mesos.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.mesos.Protos.TaskStatus +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + val state = parent.scheduler.getSchedulerState() + val queuedHeaders = Seq("Driver ID", "Submit Date", "Main Class", "Driver Resources") + val driverHeaders = queuedHeaders ++ + Seq("Start Date", "Mesos Slave ID", "State") + val retryHeaders = Seq("Driver ID", "Submit Date", "Description") ++ + Seq("Last Failed Status", "Next Retry Time", "Attempt Count") + val queuedTable = UIUtils.listingTable(queuedHeaders, queuedRow, state.queuedDrivers) + val launchedTable = UIUtils.listingTable(driverHeaders, driverRow, state.launchedDrivers) + val finishedTable = UIUtils.listingTable(driverHeaders, driverRow, state.finishedDrivers) + val retryTable = UIUtils.listingTable(retryHeaders, retryRow, state.pendingRetryDrivers) + val content = +

+      <p>Mesos Framework ID: {state.frameworkId}</p>
+      <div>
+        <div class="row-fluid">
+          <div class="span12">
+            <h4>Queued Drivers:</h4>
+            {queuedTable}
+            <h4>Launched Drivers:</h4>
+            {launchedTable}
+            <h4>Finished Drivers:</h4>
+            {finishedTable}
+            <h4>Supervise drivers waiting for retry:</h4>
+            {retryTable}
+          </div>
+        </div>
+      </div>;
+    UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster")
+  }
+
+  private def queuedRow(submission: MesosDriverDescription): Seq[Node] = {
+    <tr>
+      <td>{submission.submissionId}</td>
+      <td>{submission.submissionDate}</td>
+      <td>{submission.command.mainClass}</td>
+      <td>cpus: {submission.cores}, mem: {submission.mem}</td>
+    </tr>
+  }
+
+  private def driverRow(state: MesosClusterSubmissionState): Seq[Node] = {
+    <tr>
+      <td>{state.driverDescription.submissionId}</td>
+      <td>{state.driverDescription.submissionDate}</td>
+      <td>{state.driverDescription.command.mainClass}</td>
+      <td>cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem}</td>
+      <td>{state.startDate}</td>
+      <td>{state.slaveId.getValue}</td>
+      <td>{stateString(state.mesosTaskStatus)}</td>
+    </tr>
+  }
+
+  private def retryRow(submission: MesosDriverDescription): Seq[Node] = {
+    <tr>
+      <td>{submission.submissionId}</td>
+      <td>{submission.submissionDate}</td>
+      <td>{submission.command.mainClass}</td>
+      <td>{submission.retryState.get.lastFailureStatus}</td>
+      <td>{submission.retryState.get.nextRetry}</td>
+      <td>{submission.retryState.get.retries}</td>
+    </tr>
+  }
+
+  private def stateString(status: Option[TaskStatus]): String = {
+    if (status.isEmpty) {
+      return ""
+    }
+    val sb = new StringBuilder
+    val s = status.get
+    sb.append(s"State: ${s.getState}")
+    if (status.get.hasMessage) {
+      sb.append(s", Message: ${s.getMessage}")
+    }
+    if (status.get.hasHealthy) {
+      sb.append(s", Healthy: ${s.getHealthy}")
+    }
+    if (status.get.hasSource) {
+      sb.append(s", Source: ${s.getSource}")
+    }
+    if (status.get.hasReason) {
+      sb.append(s", Reason: ${s.getReason}")
+    }
+    if (status.get.hasTimestamp) {
+      sb.append(s", Time: ${s.getTimestamp}")
+    }
+    sb.toString()
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
new file mode 100644
index 0000000000000..4865d46dbc4ab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.deploy.mesos.ui + +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.ui.{SparkUI, WebUI} + +/** + * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] + */ +private[spark] class MesosClusterUI( + securityManager: SecurityManager, + port: Int, + conf: SparkConf, + dispatcherPublicAddress: String, + val scheduler: MesosClusterScheduler) + extends WebUI(securityManager, port, conf) { + + initialize() + + def activeWebUiUrl: String = "http://" + dispatcherPublicAddress + ":" + boundPort + + override def initialize() { + attachPage(new MesosClusterPage(this)) + attachHandler(createStaticHandler(MesosClusterUI.STATIC_RESOURCE_DIR, "/static")) + } +} + +private object MesosClusterUI { + val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala similarity index 93% rename from core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala rename to core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index b8fd406fb6f9a..307cebfb4bd09 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -30,9 +30,7 @@ import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} import org.apache.spark.util.Utils /** - * A client that submits applications to the standalone Master using a REST protocol. - * This client is intended to communicate with the [[StandaloneRestServer]] and is - * currently used for cluster mode only. + * A client that submits applications to a [[RestSubmissionServer]]. * * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action], * where [action] can be one of create, kill, or status. Each type of request is represented in @@ -53,8 +51,10 @@ import org.apache.spark.util.Utils * implementation of this client can use that information to retry using the version specified * by the server. */ -private[deploy] class StandaloneRestClient extends Logging { - import StandaloneRestClient._ +private[spark] class RestSubmissionClient extends Logging { + import RestSubmissionClient._ + + private val supportedMasterPrefixes = Seq("spark://", "mesos://") /** * Submit an application specified by the parameters in the provided request. @@ -62,7 +62,7 @@ private[deploy] class StandaloneRestClient extends Logging { * If the submission was successful, poll the status of the submission and report * it to the user. Otherwise, report the error message provided by the server. */ - private[rest] def createSubmission( + def createSubmission( master: String, request: CreateSubmissionRequest): SubmitRestProtocolResponse = { logInfo(s"Submitting a request to launch an application in $master.") @@ -107,7 +107,7 @@ private[deploy] class StandaloneRestClient extends Logging { } /** Construct a message that captures the specified parameters for submitting an application. 
*/ - private[rest] def constructSubmitRequest( + def constructSubmitRequest( appResource: String, mainClass: String, appArgs: Array[String], @@ -219,14 +219,23 @@ private[deploy] class StandaloneRestClient extends Logging { /** Return the base URL for communicating with the server, including the protocol version. */ private def getBaseUrl(master: String): String = { - val masterUrl = master.stripPrefix("spark://").stripSuffix("/") + var masterUrl = master + supportedMasterPrefixes.foreach { prefix => + if (master.startsWith(prefix)) { + masterUrl = master.stripPrefix(prefix) + } + } + masterUrl = masterUrl.stripSuffix("/") s"http://$masterUrl/$PROTOCOL_VERSION/submissions" } /** Throw an exception if this is not standalone mode. */ private def validateMaster(master: String): Unit = { - if (!master.startsWith("spark://")) { - throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + val valid = supportedMasterPrefixes.exists { prefix => master.startsWith(prefix) } + if (!valid) { + throw new IllegalArgumentException( + "This REST client only supports master URLs that start with " + + "one of the following: " + supportedMasterPrefixes.mkString(",")) } } @@ -295,7 +304,7 @@ private[deploy] class StandaloneRestClient extends Logging { } } -private[rest] object StandaloneRestClient { +private[spark] object RestSubmissionClient { private val REPORT_DRIVER_STATUS_INTERVAL = 1000 private val REPORT_DRIVER_STATUS_MAX_TRIES = 10 val PROTOCOL_VERSION = "v1" @@ -315,7 +324,7 @@ private[rest] object StandaloneRestClient { } val sparkProperties = conf.getAll.toMap val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") } - val client = new StandaloneRestClient + val client = new RestSubmissionClient val submitRequest = client.constructSubmitRequest( appResource, mainClass, appArgs, sparkProperties, environmentVariables) client.createSubmission(master, submitRequest) @@ -323,7 +332,7 @@ private[rest] object StandaloneRestClient { def main(args: Array[String]): Unit = { if (args.size < 2) { - sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]") + sys.error("Usage: RestSubmissionClient [app resource] [main class] [app args*]") sys.exit(1) } val appResource = args(0) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala new file mode 100644 index 0000000000000..2e78d03e5c0cc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. 
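For orientation, a minimal sketch of driving the renamed client programmatically, mirroring what its main() does (the master URL, jar path, and class name are placeholders rather than values from this patch):

  import org.apache.spark.SparkConf
  import org.apache.spark.deploy.rest.RestSubmissionClient

  // Illustrative only; the client is private[spark] and normally invoked via spark-submit.
  val conf = new SparkConf()
  val sparkProperties = conf.getAll.toMap
  val environmentVariables = sys.env.filter { case (k, _) => k.startsWith("SPARK_") }

  val client = new RestSubmissionClient
  val request = client.constructSubmitRequest(
    "hdfs:///jars/my-app.jar",   // appResource (placeholder)
    "org.example.Main",          // mainClass (placeholder)
    Array("arg1"),               // appArgs
    sparkProperties,
    environmentVariables)

  // After this change the same client accepts both spark:// and mesos:// masters;
  // any other prefix makes validateMaster throw an IllegalArgumentException.
  client.createSubmission("mesos://dispatcher-host:7077", request)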
+ */ + +package org.apache.spark.deploy.rest + +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} + +import scala.io.Source +import com.fasterxml.jackson.core.JsonProcessingException +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * + * This server responds with different HTTP codes depending on the situation: + * 200 OK - Request was processed successfully + * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type + * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand + * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request + * + * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]] + * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]] + * instead of the one expected by the client. If the construction of this error response itself + * fails, the response will consist of an empty body with a response code that indicates internal + * server error. + */ +private[spark] abstract class RestSubmissionServer( + val host: String, + val requestedPort: Int, + val masterConf: SparkConf) extends Logging { + protected val submitRequestServlet: SubmitRequestServlet + protected val killRequestServlet: KillRequestServlet + protected val statusRequestServlet: StatusRequestServlet + + private var _server: Option[Server] = None + + // A mapping from URL prefixes to servlets that serve them. Exposed for testing. + protected val baseContext = s"/${RestSubmissionServer.PROTOCOL_VERSION}/submissions" + protected lazy val contextToServlet = Map[String, RestServlet]( + s"$baseContext/create/*" -> submitRequestServlet, + s"$baseContext/kill/*" -> killRequestServlet, + s"$baseContext/status/*" -> statusRequestServlet, + "/*" -> new ErrorServlet // default handler + ) + + /** Start the server and return the bound port. */ + def start(): Int = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort + } + + /** + * Map the servlets to their corresponding contexts and attach them to a server. + * Return a 2-tuple of the started server and the bound port. 
+ */ + private def doStart(startPort: Int): (Server, Int) = { + val server = new Server(new InetSocketAddress(host, startPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) + val mainHandler = new ServletContextHandler + mainHandler.setContextPath("/") + contextToServlet.foreach { case (prefix, servlet) => + mainHandler.addServlet(new ServletHolder(servlet), prefix) + } + server.setHandler(mainHandler) + server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } +} + +private[rest] object RestSubmissionServer { + val PROTOCOL_VERSION = RestSubmissionClient.PROTOCOL_VERSION + val SC_UNKNOWN_PROTOCOL_VERSION = 468 +} + +/** + * An abstract servlet for handling requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class RestServlet extends HttpServlet with Logging { + + /** + * Serialize the given response message to JSON and send it through the response servlet. + * This validates the response before sending it to ensure it is properly constructed. + */ + protected def sendResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): Unit = { + val message = validateResponse(responseMessage, responseServlet) + responseServlet.setContentType("application/json") + responseServlet.setCharacterEncoding("utf-8") + responseServlet.getWriter.write(message.toJson) + } + + /** + * Return any fields in the client request message that the server does not know about. + * + * The mechanism for this is to reconstruct the JSON on the server side and compare the + * diff between this JSON and the one generated on the client side. Any fields that are + * only in the client JSON are treated as unexpected. + */ + protected def findUnknownFields( + requestJson: String, + requestMessage: SubmitRestProtocolMessage): Array[String] = { + val clientSideJson = parse(requestJson) + val serverSideJson = parse(requestMessage.toJson) + val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) + unknown match { + case j: JObject => j.obj.map { case (k, _) => k }.toArray + case _ => Array.empty[String] // No difference + } + } + + /** Return a human readable String representation of the exception. */ + protected def formatException(e: Throwable): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } + + /** Construct an error message to signal the fact that an exception has been thrown. */ + protected def handleError(message: String): ErrorResponse = { + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e + } + + /** + * Parse a submission ID from the relative path, assuming it is the first part of the path. + * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. + * The returned submission ID cannot be empty. If the path is unexpected, return None. + */ + protected def parseSubmissionId(path: String): Option[String] = { + if (path == null || path.isEmpty) { + None + } else { + path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) + } + } + + /** + * Validate the response to ensure that it is correctly constructed. + * + * If it is, simply return the message as is. Otherwise, return an error response instead + * to propagate the exception back to the client and set the appropriate error code. 
+ */ + private def validateResponse( + responseMessage: SubmitRestProtocolResponse, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + try { + responseMessage.validate() + responseMessage + } catch { + case e: Exception => + responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) + handleError("Internal server error: " + formatException(e)) + } + } +} + +/** + * A servlet for handling kill requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class KillRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, have the Master kill the corresponding + * driver and return an appropriate response to the client. Otherwise, return error. + */ + protected override def doPost( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleKill).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in kill request.") + } + sendResponse(responseMessage, response) + } + + protected def handleKill(submissionId: String): KillSubmissionResponse +} + +/** + * A servlet for handling status requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class StatusRequestServlet extends RestServlet { + + /** + * If a submission ID is specified in the URL, request the status of the corresponding + * driver from the Master and include it in the response. Otherwise, return error. + */ + protected override def doGet( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val submissionId = parseSubmissionId(request.getPathInfo) + val responseMessage = submissionId.map(handleStatus).getOrElse { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Submission ID is missing in status request.") + } + sendResponse(responseMessage, response) + } + + protected def handleStatus(submissionId: String): SubmissionStatusResponse +} + +/** + * A servlet for handling submit requests passed to the [[RestSubmissionServer]]. + */ +private[rest] abstract class SubmitRequestServlet extends RestServlet { + + /** + * Submit an application to the Master with parameters specified in the request. + * + * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. + * If the request is successfully processed, return an appropriate response to the + * client indicating so. Otherwise, return error instead. + */ + protected override def doPost( + requestServlet: HttpServletRequest, + responseServlet: HttpServletResponse): Unit = { + val responseMessage = + try { + val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + // The response should have already been validated on the client. + // In case this is not true, validate it ourselves to avoid potential NPEs. 
+ requestMessage.validate() + handleSubmit(requestMessageJson, requestMessage, responseServlet) + } catch { + // The client failed to provide a valid JSON, so this is not our fault + case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError("Malformed request: " + formatException(e)) + } + sendResponse(responseMessage, responseServlet) + } + + protected def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse +} + +/** + * A default servlet that handles error cases that are not captured by other servlets. + */ +private class ErrorServlet extends RestServlet { + private val serverVersion = RestSubmissionServer.PROTOCOL_VERSION + + /** Service a faulty request by returning an appropriate error message to the client. */ + protected override def service( + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + val path = request.getPathInfo + val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList + var versionMismatch = false + var msg = + parts match { + case Nil => + // http://host:port/ + "Missing protocol version." + case `serverVersion` :: Nil => + // http://host:port/correct-version + "Missing the /submissions prefix." + case `serverVersion` :: "submissions" :: tail => + // http://host:port/correct-version/submissions/* + "Missing an action: please specify one of /create, /kill, or /status." + case unknownVersion :: tail => + // http://host:port/unknown-version/* + versionMismatch = true + s"Unknown protocol version '$unknownVersion'." + case _ => + // never reached + s"Malformed path $path." + } + msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." 
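Putting the routing together, the URLs this server family answers look roughly as follows (host and port are illustrative; the version segment is RestSubmissionServer.PROTOCOL_VERSION):

  // Assumed example endpoints, derived from the contextToServlet mapping above.
  val base   = "http://192.168.1.1:6066/v1/submissions"
  val create = s"$base/create"                             // POST, CreateSubmissionRequest JSON body
  val kill   = s"$base/kill/driver-20150101000000-0001"    // POST
  val status = s"$base/status/driver-20150101000000-0001"  // GET
  // Anything else falls through to ErrorServlet, which replies with 400,
  // or 468 when the version segment does not match.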
+ val error = handleError(msg) + // If there is a version mismatch, include the highest protocol version that + // this server supports in case the client wants to retry with our version + if (versionMismatch) { + error.highestProtocolVersion = serverVersion + response.setStatus(RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) + } else { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + } + sendResponse(error, response) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 2d6b8d4204795..502b9bb701ccf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -18,26 +18,16 @@ package org.apache.spark.deploy.rest import java.io.File -import java.net.InetSocketAddress -import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} - -import scala.io.Source +import javax.servlet.http.HttpServletResponse import akka.actor.ActorRef -import com.fasterxml.jackson.core.JsonProcessingException -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} -import org.eclipse.jetty.util.thread.QueuedThreadPool -import org.json4s._ -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} -import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** - * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * A server that responds to requests submitted by the [[RestSubmissionClient]]. * This is intended to be embedded in the standalone Master and used in cluster mode only. * * This server responds with different HTTP codes depending on the situation: @@ -54,173 +44,31 @@ import org.apache.spark.deploy.ClientArguments._ * * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to + * @param masterConf the conf used by the Master * @param masterActor reference to the Master actor to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to - * @param masterConf the conf used by the Master */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends Logging { - - import StandaloneRestServer._ - - private var _server: Option[Server] = None - - // A mapping from URL prefixes to servlets that serve them. Exposed for testing. - protected val baseContext = s"/$PROTOCOL_VERSION/submissions" - protected val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf), - s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf), - s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf), - "/*" -> new ErrorServlet // default handler - ) - - /** Start the server and return the bound port. 
*/ - def start(): Int = { - val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf) - _server = Some(server) - logInfo(s"Started REST server for submitting applications on port $boundPort") - boundPort - } - - /** - * Map the servlets to their corresponding contexts and attach them to a server. - * Return a 2-tuple of the started server and the bound port. - */ - private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, startPort)) - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - val mainHandler = new ServletContextHandler - mainHandler.setContextPath("/") - contextToServlet.foreach { case (prefix, servlet) => - mainHandler.addServlet(new ServletHolder(servlet), prefix) - } - server.setHandler(mainHandler) - server.start() - val boundPort = server.getConnectors()(0).getLocalPort - (server, boundPort) - } - - def stop(): Unit = { - _server.foreach(_.stop()) - } -} - -private[rest] object StandaloneRestServer { - val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION - val SC_UNKNOWN_PROTOCOL_VERSION = 468 -} - -/** - * An abstract servlet for handling requests passed to the [[StandaloneRestServer]]. - */ -private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging { - - /** - * Serialize the given response message to JSON and send it through the response servlet. - * This validates the response before sending it to ensure it is properly constructed. - */ - protected def sendResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): Unit = { - val message = validateResponse(responseMessage, responseServlet) - responseServlet.setContentType("application/json") - responseServlet.setCharacterEncoding("utf-8") - responseServlet.getWriter.write(message.toJson) - } - - /** - * Return any fields in the client request message that the server does not know about. - * - * The mechanism for this is to reconstruct the JSON on the server side and compare the - * diff between this JSON and the one generated on the client side. Any fields that are - * only in the client JSON are treated as unexpected. - */ - protected def findUnknownFields( - requestJson: String, - requestMessage: SubmitRestProtocolMessage): Array[String] = { - val clientSideJson = parse(requestJson) - val serverSideJson = parse(requestMessage.toJson) - val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson) - unknown match { - case j: JObject => j.obj.map { case (k, _) => k }.toArray - case _ => Array.empty[String] // No difference - } - } - - /** Return a human readable String representation of the exception. */ - protected def formatException(e: Throwable): String = { - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"$e\n$stackTraceString" - } - - /** Construct an error message to signal the fact that an exception has been thrown. */ - protected def handleError(message: String): ErrorResponse = { - val e = new ErrorResponse - e.serverSparkVersion = sparkVersion - e.message = message - e - } - - /** - * Parse a submission ID from the relative path, assuming it is the first part of the path. - * For instance, we expect the path to take the form /[submission ID]/maybe/something/else. - * The returned submission ID cannot be empty. If the path is unexpected, return None. 
- */ - protected def parseSubmissionId(path: String): Option[String] = { - if (path == null || path.isEmpty) { - None - } else { - path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty) - } - } - - /** - * Validate the response to ensure that it is correctly constructed. - * - * If it is, simply return the message as is. Otherwise, return an error response instead - * to propagate the exception back to the client and set the appropriate error code. - */ - private def validateResponse( - responseMessage: SubmitRestProtocolResponse, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - try { - responseMessage.validate() - responseMessage - } catch { - case e: Exception => - responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR) - handleError("Internal server error: " + formatException(e)) - } - } + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + protected override val killRequestServlet = + new StandaloneKillRequestServlet(masterActor, masterConf) + protected override val statusRequestServlet = + new StandaloneStatusRequestServlet(masterActor, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, have the Master kill the corresponding - * driver and return an appropriate response to the client. Otherwise, return error. - */ - protected override def doPost( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleKill).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in kill request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -238,23 +86,8 @@ private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * If a submission ID is specified in the URL, request the status of the corresponding - * driver from the Master and include it in the response. Otherwise, return error. 
- */ - protected override def doGet( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val submissionId = parseSubmissionId(request.getPathInfo) - val responseMessage = submissionId.map(handleStatus).getOrElse { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Submission ID is missing in status request.") - } - sendResponse(responseMessage, response) - } +private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) + extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { val askTimeout = RpcUtils.askTimeout(conf) @@ -276,71 +109,11 @@ private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf) /** * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. */ -private[rest] class SubmitRequestServlet( +private[rest] class StandaloneSubmitRequestServlet( masterActor: ActorRef, masterUrl: String, conf: SparkConf) - extends StandaloneRestServlet { - - /** - * Submit an application to the Master with parameters specified in the request. - * - * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON. - * If the request is successfully processed, return an appropriate response to the - * client indicating so. Otherwise, return error instead. - */ - protected override def doPost( - requestServlet: HttpServletRequest, - responseServlet: HttpServletResponse): Unit = { - val responseMessage = - try { - val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) - // The response should have already been validated on the client. - // In case this is not true, validate it ourselves to avoid potential NPEs. - requestMessage.validate() - handleSubmit(requestMessageJson, requestMessage, responseServlet) - } catch { - // The client failed to provide a valid JSON, so this is not our fault - case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError("Malformed request: " + formatException(e)) - } - sendResponse(responseMessage, responseServlet) - } - - /** - * Handle the submit request and construct an appropriate response to return to the client. - * - * This assumes that the request message is already successfully validated. - * If the request message is not of the expected type, return error to the client. 
- */ - private def handleSubmit( - requestMessageJson: String, - requestMessage: SubmitRestProtocolMessage, - responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { - requestMessage match { - case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) - val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) - val submitResponse = new CreateSubmissionResponse - submitResponse.serverSparkVersion = sparkVersion - submitResponse.message = response.message - submitResponse.success = response.success - submitResponse.submissionId = response.driverId.orNull - val unknownFields = findUnknownFields(requestMessageJson, requestMessage) - if (unknownFields.nonEmpty) { - // If there are fields that the server does not know about, warn the client - submitResponse.unknownFields = unknownFields - } - submitResponse - case unexpected => - responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) - handleError(s"Received message of unexpected type ${unexpected.messageType}.") - } - } + extends SubmitRequestServlet { /** * Build a driver description from the fields specified in the submit request. @@ -389,50 +162,37 @@ private[rest] class SubmitRequestServlet( new DriverDescription( appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) } -} -/** - * A default servlet that handles error cases that are not captured by other servlets. - */ -private class ErrorServlet extends StandaloneRestServlet { - private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION - - /** Service a faulty request by returning an appropriate error message to the client. */ - protected override def service( - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - val path = request.getPathInfo - val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList - var versionMismatch = false - var msg = - parts match { - case Nil => - // http://host:port/ - "Missing protocol version." - case `serverVersion` :: Nil => - // http://host:port/correct-version - "Missing the /submissions prefix." - case `serverVersion` :: "submissions" :: tail => - // http://host:port/correct-version/submissions/* - "Missing an action: please specify one of /create, /kill, or /status." - case unknownVersion :: tail => - // http://host:port/unknown-version/* - versionMismatch = true - s"Unknown protocol version '$unknownVersion'." - case _ => - // never reached - s"Malformed path $path." - } - msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..." - val error = handleError(msg) - // If there is a version mismatch, include the highest protocol version that - // this server supports in case the client wants to retry with our version - if (versionMismatch) { - error.highestProtocolVersion = serverVersion - response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) - } else { - response.setStatus(HttpServletResponse.SC_BAD_REQUEST) + /** + * Handle the submit request and construct an appropriate response to return to the client. + * + * This assumes that the request message is already successfully validated. + * If the request message is not of the expected type, return error to the client. 
+ */ + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val askTimeout = RpcUtils.askTimeout(conf) + val driverDescription = buildDriverDescription(submitRequest) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val submitResponse = new CreateSubmissionResponse + submitResponse.serverSparkVersion = sparkVersion + submitResponse.message = response.message + submitResponse.success = response.success + submitResponse.submissionId = response.driverId.orNull + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + submitResponse.unknownFields = unknownFields + } + submitResponse + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") } - sendResponse(error, response) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index d80abdf15fb34..0d50a768942ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -61,7 +61,7 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { assertProperty[Boolean](key, "boolean", _.toBoolean) private def assertPropertyIsNumeric(key: String): Unit = - assertProperty[Int](key, "numeric", _.toInt) + assertProperty[Double](key, "numeric", _.toDouble) private def assertPropertyIsMemory(key: String): Unit = assertProperty[Int](key, "memory", Utils.memoryStringToMb) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala index 8fde8c142a4c1..0e226ee294cab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala @@ -35,7 +35,7 @@ private[rest] abstract class SubmitRestProtocolResponse extends SubmitRestProtoc /** * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol. */ -private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -46,7 +46,7 @@ private[rest] class CreateSubmissionResponse extends SubmitRestProtocolResponse /** * A response to a kill request in the REST application submission protocol. */ -private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { +private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse { var submissionId: String = null protected override def doValidate(): Unit = { super.doValidate() @@ -58,7 +58,7 @@ private[rest] class KillSubmissionResponse extends SubmitRestProtocolResponse { /** * A response to a status request in the REST application submission protocol. 
*/ -private[rest] class SubmissionStatusResponse extends SubmitRestProtocolResponse { +private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse { var submissionId: String = null var driverState: String = null var workerId: String = null diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala new file mode 100644 index 0000000000000..fd17a980c9319 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest.mesos + +import java.io.File +import java.text.SimpleDateFormat +import java.util.Date +import java.util.concurrent.atomic.AtomicLong +import javax.servlet.http.HttpServletResponse + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler +import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} + + +/** + * A server that responds to requests submitted by the [[RestSubmissionClient]]. + * All requests are forwarded to + * [[org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler]]. + * This is intended to be used in Mesos cluster mode only. + * For more details about the REST submission please refer to [[RestSubmissionServer]] javadocs. + */ +private[spark] class MesosRestServer( + host: String, + requestedPort: Int, + masterConf: SparkConf, + scheduler: MesosClusterScheduler) + extends RestSubmissionServer(host, requestedPort, masterConf) { + + protected override val submitRequestServlet = + new MesosSubmitRequestServlet(scheduler, masterConf) + protected override val killRequestServlet = + new MesosKillRequestServlet(scheduler, masterConf) + protected override val statusRequestServlet = + new MesosStatusRequestServlet(scheduler, masterConf) +} + +private[deploy] class MesosSubmitRequestServlet( + scheduler: MesosClusterScheduler, + conf: SparkConf) + extends SubmitRequestServlet { + + private val DEFAULT_SUPERVISE = false + private val DEFAULT_MEMORY = 512 // mb + private val DEFAULT_CORES = 1.0 + + private val nextDriverNumber = new AtomicLong(0) + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def newDriverId(submitDate: Date): String = { + "driver-%s-%04d".format( + createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) + } + + /** + * Build a driver description from the fields specified in the submit request. 
+ * + * This involves constructing a command that launches a mesos framework for the job. + * This does not currently consider fields used by python applications since python + * is not supported in mesos cluster mode yet. + */ + private def buildDriverDescription(request: CreateSubmissionRequest): MesosDriverDescription = { + // Required fields, including the main class because python is not yet supported + val appResource = Option(request.appResource).getOrElse { + throw new SubmitRestMissingFieldException("Application jar is missing.") + } + val mainClass = Option(request.mainClass).getOrElse { + throw new SubmitRestMissingFieldException("Main class is missing.") + } + + // Optional fields + val sparkProperties = request.sparkProperties + val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions") + val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath") + val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") + val superviseDriver = sparkProperties.get("spark.driver.supervise") + val driverMemory = sparkProperties.get("spark.driver.memory") + val driverCores = sparkProperties.get("spark.driver.cores") + val appArgs = request.appArgs + val environmentVariables = request.environmentVariables + val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) + + // Construct driver description + val conf = new SparkConf(false).setAll(sparkProperties) + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + mainClass, appArgs, environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toDouble).getOrElse(DEFAULT_CORES) + val submitDate = new Date() + val submissionId = newDriverId(submitDate) + + new MesosDriverDescription( + name, appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, + command, request.sparkProperties, submissionId, submitDate) + } + + protected override def handleSubmit( + requestMessageJson: String, + requestMessage: SubmitRestProtocolMessage, + responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { + requestMessage match { + case submitRequest: CreateSubmissionRequest => + val driverDescription = buildDriverDescription(submitRequest) + val s = scheduler.submitDriver(driverDescription) + s.serverSparkVersion = sparkVersion + val unknownFields = findUnknownFields(requestMessageJson, requestMessage) + if (unknownFields.nonEmpty) { + // If there are fields that the server does not know about, warn the client + s.unknownFields = unknownFields + } + s + case unexpected => + responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST) + handleError(s"Received message of unexpected type ${unexpected.messageType}.") + } + } +} + +private[deploy] class MesosKillRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends KillRequestServlet { + protected override def handleKill(submissionId: String): KillSubmissionResponse = { + val k = scheduler.killDriver(submissionId) + 
k.serverSparkVersion = sparkVersion + k + } +} + +private[deploy] class MesosStatusRequestServlet(scheduler: MesosClusterScheduler, conf: SparkConf) + extends StatusRequestServlet { + protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { + val d = scheduler.getDriverStatus(submissionId) + d.serverSparkVersion = sparkVersion + d + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3ee2eb69e8a4e..8f3cc54051048 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -34,6 +34,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem @@ -61,7 +62,7 @@ private[worker] class Worker( assert (port > 0) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -85,10 +86,10 @@ private[worker] class Worker( private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders - private val CLEANUP_INTERVAL_MILLIS = + private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") @@ -112,7 +113,7 @@ private[worker] class Worker( } else { new File(sys.env.get("SPARK_HOME").getOrElse(".")) } - + var workDir: File = null val finishedExecutors = new HashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] @@ -122,7 +123,7 @@ private[worker] class Worker( val finishedApps = new HashSet[String] // The shuffle service is not actually started unless configured. 
- private val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + private val shuffleService = new ExternalShuffleService(conf, securityMgr) private val publicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") @@ -134,7 +135,7 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - + private var registrationRetryTimer: Option[Cancellable] = None var coresUsed = 0 diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8af46f3327adb..79aed90b53e2f 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -57,7 +57,7 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo("Connecting to driver: " + driverUrl) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => driver = Some(ref) - ref.sendWithReply[RegisteredExecutor.type]( + ref.ask[RegisteredExecutor.type]( RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls)) } onComplete { case Success(msg) => Utils.tryLogNonFatalError { @@ -154,7 +154,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { executorConf, new SecurityManager(executorConf)) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ + val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f57e215c3f2ed..8f916e0502ecb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ /** @@ -178,6 +179,7 @@ private[spark] class Executor( } override def run(): Unit = { + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() @@ -190,6 +192,7 @@ private[spark] class Executor( val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. @@ -206,7 +209,21 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() - val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + val value = try { + task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + } finally { + // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; + // when changing this, make sure to update both copies. + val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } + } val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. @@ -424,7 +441,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 0709b6d689e86..0756cdb2ed8e6 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -97,7 +97,7 @@ private[spark] object CompressionCodec { /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.lz4.block.size`. + * Block size can be configured by `spark.io.compression.lz4.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. This is intended for use as an internal compression utility within a single Spark @@ -107,7 +107,7 @@ private[spark] object CompressionCodec { class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.lz4.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.lz4.blockSize", "32k").toInt new LZ4BlockOutputStream(s, blockSize) } @@ -137,7 +137,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. - * Block size can be configured by `spark.io.compression.snappy.block.size`. + * Block size can be configured by `spark.io.compression.snappy.blockSize`. * * Note: The wire protocol for this codec is not guaranteed to be compatible across versions * of Spark. 
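A short sketch of how the renamed block-size keys are meant to be used after this change (codec choice and size are illustrative):

  import org.apache.spark.SparkConf

  val conf = new SparkConf()
    .set("spark.io.compression.codec", "lz4")
    // The old spark.io.compression.{lz4,snappy}.block.size keys took a raw byte count;
    // the renamed keys are read with getSizeAsBytes, so size strings such as "64k" work.
    .set("spark.io.compression.lz4.blockSize", "64k")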
This is intended for use as an internal compression utility within a single Spark @@ -153,7 +153,7 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { } override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768) + val blockSize = conf.getSizeAsBytes("spark.io.compression.snappy.blockSize", "32k").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 92b0641d0fb6e..7598ff617b399 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -60,6 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag]( var rdds: Seq[RDD[T]] ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) { require(rdds.length > 0) + require(rdds.forall(_.partitioner.isDefined)) require(rdds.flatMap(_.partitioner).toSet.size == 1, "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala new file mode 100644 index 0000000000000..3e5b64265e919 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * and can be called in any thread. + */ +private[spark] trait RpcCallContext { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Report a failure to the sender. + */ + def sendFailure(e: Throwable): Unit + + /** + * The sender of this message. + */ + def sender: RpcEndpointRef +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala new file mode 100644 index 0000000000000..d2b2baef1d8c4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
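The extra require in PartitionerAwareUnionRDD rejects parents with an undefined partitioner before the existing same-partitioner check runs. A sketch under the assumption that the calling code sits in the org.apache.spark package (the class is not part of the public API) and that a SparkContext named sc exists:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.PartitionerAwareUnionRDD

    val p = new HashPartitioner(4)
    val a = sc.parallelize(Seq(1 -> "x")).partitionBy(p)
    val b = sc.parallelize(Seq(2 -> "y")).partitionBy(p)
    new PartitionerAwareUnionRDD(sc, Seq(a, b))   // OK: both partitioners defined and equal
    // new PartitionerAwareUnionRDD(sc, Seq(a, sc.parallelize(Seq(3 -> "z"))))
    //   would now fail the first require, since the second parent has no partitioner.
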
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint + + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart receive* onStop + * + * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * [[ThreadSafeRpcEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is + * called. And `self` will become `null` when `onStop` is called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. + */ + final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + */ + def receive: PartialFunction[Any, Unit] = { + case _ => throw new SparkException(self + " does not implement 'receive'") + } + + /** + * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. + */ + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => context.sendFailure(new SparkException(self + " won't reply anything")) + } + + /** + * Invoked when any exception is thrown during handling messages. + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. 
+ } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(_self) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala new file mode 100644 index 0000000000000..69181edb9ad44 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.FiniteDuration +import scala.reflect.ClassTag + +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SparkException, Logging, SparkConf} + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = RpcUtils.numRetries(conf) + private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) + private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a [[Future]] to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. 
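Taken together, receive handles fire-and-forget messages while receiveAndReply gets an RpcCallContext to answer or fail an ask. A minimal endpoint sketch (the Echo message and EchoEndpoint names are illustrative only, and the traits are private[spark], so this assumes code under org.apache.spark):

    import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

    case class Echo(msg: String)

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
      override def onStart(): Unit = {}                      // self is valid from here on
      override def receive: PartialFunction[Any, Unit] = {
        case Echo(m) => // messages sent with send() land here; no reply expected
      }
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case Echo(m) if m.nonEmpty => context.reply(m)       // answer an ask()
        case Echo(_) => context.sendFailure(new IllegalArgumentException("empty message"))
      }
    }
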
+ */ + def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithRetry[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = ask[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a5336b7563802..12b6b28d4d7ec 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -20,13 +20,41 @@ package org.apache.spark.rpc import java.net.URI import scala.concurrent.{Await, Future} -import scala.concurrent.duration._ import scala.language.postfixOps -import scala.reflect.ClassTag -import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{RpcUtils, Utils} + +/** + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. + */ +private[spark] object RpcEnv { + + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). 
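getRpcEnvFactory resolves spark.rpc either as a short name from the map above or as a fully qualified RpcEnvFactory class name, so a different transport can be plugged in by configuration. A sketch; the custom class name is hypothetical:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
    // Default: the short name "akka" maps to org.apache.spark.rpc.akka.AkkaRpcEnvFactory.
    // Alternatively, point spark.rpc at any class implementing RpcEnvFactory:
    conf.set("spark.rpc", "com.example.rpc.MyRpcEnvFactory")   // hypothetical factory class
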
+ newInstance().asInstanceOf[RpcEnvFactory] + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + getRpcEnvFactory(conf).create(config) + } + +} + + /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote @@ -112,6 +140,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { def uriOf(systemName: String, address: RpcAddress, endpointName: String): String } + private[spark] case class RpcEnvConfig( conf: SparkConf, name: String, @@ -119,261 +148,9 @@ private[spark] case class RpcEnvConfig( port: Int, securityManager: SecurityManager) -/** - * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor - * so that it can be created via Reflection. - */ -private[spark] object RpcEnv { - - private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "akka") - val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). - newInstance().asInstanceOf[RpcEnvFactory] - } - - def create( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): RpcEnv = { - // Using Reflection to create the RpcEnv to avoid to depend on Akka directly - val config = RpcEnvConfig(conf, name, host, port, securityManager) - getRpcEnvFactory(conf).create(config) - } - -} - -/** - * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be - * created using Reflection. - */ -private[spark] trait RpcEnvFactory { - - def create(config: RpcEnvConfig): RpcEnv -} /** - * An end point for the RPC that defines what functions to trigger given a message. - * - * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. - * - * The lift-cycle will be: - * - * constructor onStart receive* onStop - * - * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use - * [[ThreadSafeRpcEndpoint]] - * - * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be - * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. - */ -private[spark] trait RpcEndpoint { - - /** - * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. - */ - val rpcEnv: RpcEnv - - /** - * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. `self` will become valid when `onStart` is - * called. And `self` will become `null` when `onStop` is called. - * - * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not - * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. - */ - final def self: RpcEndpointRef = { - require(rpcEnv != null, "rpcEnv has not been initialized") - rpcEnv.endpointRef(this) - } - - /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. 
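Putting the pieces together, a sketch of standing up an RpcEnv, registering the hypothetical EchoEndpoint from the earlier sketch, and talking to it (again assuming code inside the org.apache.spark package; port 0 is taken to mean an ephemeral port):

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.RpcEnv

    val conf = new SparkConf()
    // The factory configured via spark.rpc ("akka" by default) is looked up reflectively.
    val rpcEnv = RpcEnv.create("demo", "localhost", 0, conf, new SecurityManager(conf))
    val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))
    ref.send(Echo("fire and forget"))                       // handled by receive
    val answered = ref.askWithRetry[String](Echo("needs a reply"))   // handled by receiveAndReply
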
- */ - def receive: PartialFunction[Any, Unit] = { - case _ => throw new SparkException(self + " does not implement 'receive'") - } - - /** - * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. - */ - def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case _ => context.sendFailure(new SparkException(self + " won't reply anything")) - } - - /** - * Call onError when any exception is thrown during handling messages. - * - * @param cause - */ - def onError(cause: Throwable): Unit = { - // By default, throw e and let RpcEnv handle it - throw cause - } - - /** - * Invoked before [[RpcEndpoint]] starts to handle any message. - */ - def onStart(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when [[RpcEndpoint]] is stopping. - */ - def onStop(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when `remoteAddress` is connected to the current node. - */ - def onConnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when `remoteAddress` is lost. - */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. - */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * A convenient method to stop [[RpcEndpoint]]. - */ - final def stop(): Unit = { - val _self = self - if (_self != null) { - rpcEnv.stop(_self) - } - } -} - -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -trait ThreadSafeRpcEndpoint extends RpcEndpoint - -/** - * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. - */ -private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) - extends Serializable with Logging { - - private[this] val maxRetries = RpcUtils.numRetries(conf) - private[this] val retryWaitMs = RpcUtils.retryWaitMs(conf) - private[this] val defaultAskTimeout = RpcUtils.askTimeout(conf) - - /** - * return the address for the [[RpcEndpointRef]] - */ - def address: RpcAddress - - def name: String - - /** - * Sends a one-way asynchronous message. Fire-and-forget semantics. - */ - def send(message: Any): Unit - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within a default timeout. - * - * This method only sends the message once and never retries. - */ - def sendWithReply[T: ClassTag](message: Any): Future[T] = - sendWithReply(message, defaultAskTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within the specified timeout. - * - * This method only sends the message once and never retries. 
- */ - def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] - - /** - * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default - * timeout, or throw a SparkException if this fails even after the default number of retries. - * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this - * method retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message - * loop of [[RpcEndpoint]]. - * - * @param message the message to send - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultAskTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a - * specified timeout, throw a SparkException if this fails even after the specified number of - * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method - * retries, the message handling in the receiver side should be idempotent. - * - * Note: this is a blocking action which may cost a lot of time, so don't call it in an message - * loop of [[RpcEndpoint]]. - * - * @param message the message to send - * @param timeout the timeout duration - * @tparam T type of the reply message - * @return the reply message from the corresponding [[RpcEndpoint]] - */ - def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { - // TODO: Consider removing multiple attempts - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = sendWithReply[T](message, timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - Thread.sleep(retryWaitMs) - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - -} - -/** - * Represent a host with a port + * Represents a host and port. */ private[spark] case class RpcAddress(host: String, port: Int) { // TODO do we need to add the type of RpcEnv in the address? @@ -383,6 +160,7 @@ private[spark] case class RpcAddress(host: String, port: Int) { override val toString: String = hostPort } + private[spark] object RpcAddress { /** @@ -404,26 +182,3 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } - -/** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe - * and can be called in any thread. - */ -private[spark] trait RpcCallContext { - - /** - * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] - * will be called. - */ - def reply(response: Any): Unit - - /** - * Report a failure to the sender. - */ - def sendFailure(e: Throwable): Unit - - /** - * The sender of this message. 
- */ - def sender: RpcEndpointRef -} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 652e52f2b2e73..ba0d468f111ef 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -293,7 +293,7 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { case msg @ AkkaMessage(message, reply) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c4bff4e83afc..b511c306ca508 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -166,7 +167,7 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - blockManagerMaster.driverEndpoint.askWithReply[Boolean]( + blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(blockManagerId), 600 seconds) } @@ -643,8 +644,15 @@ class DAGScheduler( try { val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) - val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, - attemptNumber = 0, runningLocally = true) + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) + val taskContext = + new TaskContextImpl( + job.finalStage.id, + job.partitions(0), + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + runningLocally = true) TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) @@ -652,6 +660,16 @@ class DAGScheduler( } finally { taskContext.markTaskCompleted() TaskContext.unset() + // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, + // make sure to update both copies. 
+ val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") + } else { + logError(s"Managed memory leak detected; size = $freedMemory bytes") + } + } } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 7c184b1dcb308..0b1d47cff3746 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -85,7 +85,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { val msg = AskPermissionToCommitOutput(stage, partition, attempt) coordinatorRef match { case Some(endpointRef) => - endpointRef.askWithReply[Boolean](msg) + endpointRef.askWithRetry[Boolean](msg) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index b09b19e2ac9e7..586d1e06204c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * @return the result of the task */ final def run(taskAttemptId: Long, attemptNumber: Int): T = { - context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, - taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) + context = new TaskContextImpl( + stageId = stageId, + partitionId = partitionId, + taskAttemptId = taskAttemptId, + attemptNumber = attemptNumber, + taskMemoryManager = taskMemoryManager, + runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() @@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex } } + private var taskMemoryManager: TaskMemoryManager = _ + + def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = { + this.taskMemoryManager = taskMemoryManager + } + def runTask(context: TaskContext): T def preferredLocations: Seq[TaskLocation] = Nil diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9656fb76858ea..7352fa1fe9ebd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -252,7 +252,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp try { if (driverEndpoint != null) { logInfo("Shutting down all executors") - driverEndpoint.askWithReply[Boolean](StopExecutors) 
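The same per-task memory bookkeeping now exists in both code paths: a TaskMemoryManager is created per attempt, handed to the task via setTaskMemoryManager, and drained in a finally block. A condensed sketch of that wiring, plus the flag that turns a detected leak into a failure rather than a log line (useful in tests):

    import org.apache.spark.{SparkConf, SparkException}
    import org.apache.spark.unsafe.memory.TaskMemoryManager

    // Strict mode: a managed-memory leak fails the task instead of only logging.
    val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true")

    // Condensed form of the pattern added in Executor.run and
    // DAGScheduler.runLocallyWithinThread; env, task, taskId, attemptNumber and a
    // Logging mixin (for logError) are assumed to be in scope.
    val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
    task.setTaskMemoryManager(taskMemoryManager)
    try {
      task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
    } finally {
      val freed = taskMemoryManager.cleanUpAllAllocatedMemory()
      if (freed > 0) {
        val msg = s"Managed memory leak detected; size = $freed bytes, TID = $taskId"
        if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) throw new SparkException(msg)
        else logError(msg)
      }
    }
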
+ driverEndpoint.askWithRetry[Boolean](StopExecutors) } } catch { case e: Exception => @@ -264,7 +264,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp stopExecutors() try { if (driverEndpoint != null) { - driverEndpoint.askWithReply[Boolean](StopDriver) + driverEndpoint.askWithRetry[Boolean](StopDriver) } } catch { case e: Exception => @@ -287,7 +287,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - driverEndpoint.askWithReply[Boolean](RemoveExecutor(executorId, reason)) + driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) } catch { case e: Exception => throw new SparkException("Error notifying standalone scheduler's driver endpoint", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index d987c7d563579..2a3a5d925d06f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -53,14 +53,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithReply[Boolean](RequestExecutors(requestedTotal)) + yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - yarnSchedulerEndpoint.askWithReply[Boolean](KillExecutors(executorIds)) + yarnSchedulerEndpoint.askWithRetry[Boolean](KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -115,7 +115,7 @@ private[spark] abstract class YarnSchedulerBackend( amEndpoint match { case Some(am) => Future { - context.reply(am.askWithReply[Boolean](r)) + context.reply(am.askWithRetry[Boolean](r)) } onFailure { case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) @@ -130,7 +130,7 @@ private[spark] abstract class YarnSchedulerBackend( amEndpoint match { case Some(am) => Future { - context.reply(am.askWithReply[Boolean](k)) + context.reply(am.askWithRetry[Boolean](k)) } onFailure { case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 82f652dae0378..3412301e64fd7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{List => JList} -import java.util.Collections +import java.util.{Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} - -import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException, TaskState} 
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -49,17 +46,10 @@ private[spark] class CoarseMesosSchedulerBackend( master: String) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler - with Logging { + with MesosSchedulerUtils { val MAX_SLAVE_FAILURES = 2 // Blacklist a slave after this many failures - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null - // Maximum number of cores to acquire (TODO: we'll need more flexible controls here) val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt @@ -87,26 +77,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def start() { super.start() - - synchronized { - new Thread("CoarseMesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + startScheduler(master, CoarseMesosSchedulerBackend.this, fwInfo) } def createCommand(offer: Offer, numCores: Int): CommandInfo = { @@ -150,8 +122,10 @@ private[spark] class CoarseMesosSchedulerBackend( conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) - val uri = conf.get("spark.executor.uri", null) - if (uri == null) { + val uri = conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + + if (uri.isEmpty) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend" @@ -164,7 +138,7 @@ private[spark] class CoarseMesosSchedulerBackend( } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
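The rewritten URI lookup replaces a nullable string with an Option and adds a fallback to the SPARK_EXECUTOR_URI environment variable, which the surrounding branch then checks with uri.isEmpty and unwraps with uri.get. The lookup order on its own (a SparkConf named conf is assumed in scope):

    // spark.executor.uri wins; otherwise fall back to SPARK_EXECUTOR_URI; else None.
    val uri: Option[String] = conf.getOption("spark.executor.uri")
      .orElse(Option(System.getenv("SPARK_EXECUTOR_URI")))
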
- val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue( s"cd $basename*; $prefixEnv " + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + @@ -173,7 +147,7 @@ private[spark] class CoarseMesosSchedulerBackend( s" --hostname ${offer.getHostname}" + s" --cores $numCores" + s" --app-id $appId") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } command.build() } @@ -183,18 +157,7 @@ private[spark] class CoarseMesosSchedulerBackend( override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } - } + markRegistered() } override def disconnected(d: SchedulerDriver) {} @@ -245,14 +208,6 @@ private[spark] class CoarseMesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - private def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Build a Mesos resource protobuf object */ private def createResource(resourceName: String, quantity: Double): Protos.Resource = { Resource.newBuilder() @@ -284,7 +239,8 @@ private[spark] class CoarseMesosSchedulerBackend( "is Spark installed on it?") } } - driver.reviveOffers() // In case we'd rejected everything before but have now lost a node + // In case we'd rejected everything before but have now lost a node + mesosDriver.reviveOffers() } } } @@ -296,8 +252,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def stop() { super.stop() - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala new file mode 100644 index 0000000000000..3efc536f1456c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import scala.collection.JavaConversions._ + +import org.apache.curator.framework.CuratorFramework +import org.apache.zookeeper.CreateMode +import org.apache.zookeeper.KeeperException.NoNodeException + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.deploy.SparkCuratorUtil +import org.apache.spark.util.Utils + +/** + * Persistence engine factory that is responsible for creating new persistence engines + * to store Mesos cluster mode state. + */ +private[spark] abstract class MesosClusterPersistenceEngineFactory(conf: SparkConf) { + def createEngine(path: String): MesosClusterPersistenceEngine +} + +/** + * Mesos cluster persistence engine is responsible for persisting Mesos cluster mode + * specific state, so that on failover all the state can be recovered and the scheduler + * can resume managing the drivers. + */ +private[spark] trait MesosClusterPersistenceEngine { + def persist(name: String, obj: Object): Unit + def expunge(name: String): Unit + def fetch[T](name: String): Option[T] + def fetchAll[T](): Iterable[T] +} + +/** + * Zookeeper backed persistence engine factory. + * All Zk engines created from this factory shares the same Zookeeper client, so + * all of them reuses the same connection pool. + */ +private[spark] class ZookeeperMesosClusterPersistenceEngineFactory(conf: SparkConf) + extends MesosClusterPersistenceEngineFactory(conf) { + + lazy val zk = SparkCuratorUtil.newClient(conf, "spark.mesos.deploy.zookeeper.url") + + def createEngine(path: String): MesosClusterPersistenceEngine = { + new ZookeeperMesosClusterPersistenceEngine(path, zk, conf) + } +} + +/** + * Black hole persistence engine factory that creates black hole + * persistence engines, which stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngineFactory + extends MesosClusterPersistenceEngineFactory(null) { + def createEngine(path: String): MesosClusterPersistenceEngine = { + new BlackHoleMesosClusterPersistenceEngine + } +} + +/** + * Black hole persistence engine that stores nothing. + */ +private[spark] class BlackHoleMesosClusterPersistenceEngine extends MesosClusterPersistenceEngine { + override def persist(name: String, obj: Object): Unit = {} + override def fetch[T](name: String): Option[T] = None + override def expunge(name: String): Unit = {} + override def fetchAll[T](): Iterable[T] = Iterable.empty[T] +} + +/** + * Zookeeper based Mesos cluster persistence engine, that stores cluster mode state + * into Zookeeper. Each engine object is operating under one folder in Zookeeper, but + * reuses a shared Zookeeper client. 
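The MesosClusterPersistenceEngine trait is deliberately small, so alternative backends are easy to stub out. An illustrative in-memory implementation for tests, not part of this patch (it would have to live under org.apache.spark since the trait is private[spark]):

    import scala.collection.mutable

    class InMemoryMesosClusterPersistenceEngine extends MesosClusterPersistenceEngine {
      private val store = mutable.Map.empty[String, Object]
      override def persist(name: String, obj: Object): Unit = store(name) = obj
      override def expunge(name: String): Unit = store -= name
      override def fetch[T](name: String): Option[T] = store.get(name).map(_.asInstanceOf[T])
      override def fetchAll[T](): Iterable[T] = store.values.map(_.asInstanceOf[T])
    }
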
+ */ +private[spark] class ZookeeperMesosClusterPersistenceEngine( + baseDir: String, + zk: CuratorFramework, + conf: SparkConf) + extends MesosClusterPersistenceEngine with Logging { + private val WORKING_DIR = + conf.get("spark.deploy.zookeeper.dir", "/spark_mesos_dispatcher") + "/" + baseDir + + SparkCuratorUtil.mkdir(zk, WORKING_DIR) + + def path(name: String): String = { + WORKING_DIR + "/" + name + } + + override def expunge(name: String): Unit = { + zk.delete().forPath(path(name)) + } + + override def persist(name: String, obj: Object): Unit = { + val serialized = Utils.serialize(obj) + val zkPath = path(name) + zk.create().withMode(CreateMode.PERSISTENT).forPath(zkPath, serialized) + } + + override def fetch[T](name: String): Option[T] = { + val zkPath = path(name) + + try { + val fileData = zk.getData().forPath(zkPath) + Some(Utils.deserialize[T](fileData)) + } catch { + case e: NoNodeException => None + case e: Exception => { + logWarning("Exception while reading persisted file, deleting", e) + zk.delete().forPath(zkPath) + None + } + } + } + + override def fetchAll[T](): Iterable[T] = { + zk.getChildren.forPath(WORKING_DIR).map(fetch[T]).flatten + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala new file mode 100644 index 0000000000000..0396e62be5309 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.io.File +import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, Date, List => JList} + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.mesos.Protos.Environment.Variable +import org.apache.mesos.Protos.TaskStatus.Reason +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} +import org.apache.mesos.{Scheduler, SchedulerDriver} +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} + + +/** + * Tracks the current state of a Mesos Task that runs a Spark driver. 
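How the two factories are chosen is not shown in this hunk; the sketch below assumes the dispatcher keys the decision off the same ZooKeeper URL property the factory reads, and then shows basic engine usage:

    // Assumption: ZooKeeper-backed state when spark.mesos.deploy.zookeeper.url is set,
    // otherwise the black-hole engine (the real selection happens elsewhere in the patch).
    val engineFactory =
      if (conf.getOption("spark.mesos.deploy.zookeeper.url").isDefined) {
        new ZookeeperMesosClusterPersistenceEngineFactory(conf)
      } else {
        new BlackHoleMesosClusterPersistenceEngineFactory
      }
    val schedulerState = engineFactory.createEngine("scheduler")
    schedulerState.persist("frameworkId", "example-framework-id")   // any serializable Object
    val recovered: Option[String] = schedulerState.fetch[String]("frameworkId")
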
+ * @param driverDescription Submitted driver description from + * [[org.apache.spark.deploy.rest.mesos.MesosRestServer]] + * @param taskId Mesos TaskID generated for the task + * @param slaveId Slave ID that the task is assigned to + * @param mesosTaskStatus The last known task status update. + * @param startDate The date the task was launched + */ +private[spark] class MesosClusterSubmissionState( + val driverDescription: MesosDriverDescription, + val taskId: TaskID, + val slaveId: SlaveID, + var mesosTaskStatus: Option[TaskStatus], + var startDate: Date) + extends Serializable { + + def copy(): MesosClusterSubmissionState = { + new MesosClusterSubmissionState( + driverDescription, taskId, slaveId, mesosTaskStatus, startDate) + } +} + +/** + * Tracks the retry state of a driver, which includes the next time it should be scheduled + * and necessary information to do exponential backoff. + * This class is not thread-safe, and we expect the caller to handle synchronizing state. + * @param lastFailureStatus Last Task status when it failed. + * @param retries Number of times it has been retried. + * @param nextRetry Time at which it should be retried next + * @param waitTime The amount of time driver is scheduled to wait until next retry. + */ +private[spark] class MesosClusterRetryState( + val lastFailureStatus: TaskStatus, + val retries: Int, + val nextRetry: Date, + val waitTime: Int) extends Serializable { + def copy(): MesosClusterRetryState = + new MesosClusterRetryState(lastFailureStatus, retries, nextRetry, waitTime) +} + +/** + * The full state of the cluster scheduler, currently being used for displaying + * information on the UI. + * @param frameworkId Mesos Framework id for the cluster scheduler. + * @param masterUrl The Mesos master url + * @param queuedDrivers All drivers queued to be launched + * @param launchedDrivers All launched or running drivers + * @param finishedDrivers All terminated drivers + * @param pendingRetryDrivers All drivers pending to be retried + */ +private[spark] class MesosClusterSchedulerState( + val frameworkId: String, + val masterUrl: Option[String], + val queuedDrivers: Iterable[MesosDriverDescription], + val launchedDrivers: Iterable[MesosClusterSubmissionState], + val finishedDrivers: Iterable[MesosClusterSubmissionState], + val pendingRetryDrivers: Iterable[MesosDriverDescription]) + +/** + * A Mesos scheduler that is responsible for launching submitted Spark drivers in cluster mode + * as Mesos tasks in a Mesos cluster. + * All drivers are launched asynchronously by the framework, which will eventually be launched + * by one of the slaves in the cluster. The results of the driver will be stored in slave's task + * sandbox which is accessible by visiting the Mesos UI. + * This scheduler supports recovery by persisting all its state and performs task reconciliation + * on recover, which gets all the latest state for all the drivers from Mesos master. 
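MesosClusterRetryState carries everything needed for exponential backoff, though the update rule itself appears later in the file. A hedged sketch of how such a state might be advanced, capped by spark.mesos.cluster.retry.wait.max; the exact arithmetic here is an assumption, not the patch's code:

    import java.util.Date
    import org.apache.mesos.Protos.TaskStatus

    def advance(prev: Option[MesosClusterRetryState],
                status: TaskStatus,
                maxWaitSec: Int): MesosClusterRetryState = {
      val (retries, waitSec) = prev
        .map(s => (s.retries + 1, math.min(s.waitTime * 2, maxWaitSec)))   // double the wait, up to the cap
        .getOrElse((1, 1))
      new MesosClusterRetryState(status, retries,
        new Date(System.currentTimeMillis() + waitSec * 1000L), waitSec)
    }
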
+ */ +private[spark] class MesosClusterScheduler( + engineFactory: MesosClusterPersistenceEngineFactory, + conf: SparkConf) + extends Scheduler with MesosSchedulerUtils { + var frameworkUrl: String = _ + private val metricsSystem = + MetricsSystem.createMetricsSystem("mesos_cluster", conf, new SecurityManager(conf)) + private val master = conf.get("spark.master") + private val appName = conf.get("spark.app.name") + private val queuedCapacity = conf.getInt("spark.mesos.maxDrivers", 200) + private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) + private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute + private val schedulerState = engineFactory.createEngine("scheduler") + private val stateLock = new ReentrantLock() + private val finishedDrivers = + new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) + private var frameworkId: String = null + // Holds all the launched drivers and current launch state, keyed by driver id. + private val launchedDrivers = new mutable.HashMap[String, MesosClusterSubmissionState]() + // Holds a map of driver id to expected slave id that is passed to Mesos for reconciliation. + // All drivers that are loaded after failover are added here, as we need get the latest + // state of the tasks from Mesos. + private val pendingRecover = new mutable.HashMap[String, SlaveID]() + // Stores all the submitted drivers that hasn't been launched. + private val queuedDrivers = new ArrayBuffer[MesosDriverDescription]() + // All supervised drivers that are waiting to retry after termination. + private val pendingRetryDrivers = new ArrayBuffer[MesosDriverDescription]() + private val queuedDriversState = engineFactory.createEngine("driverQueue") + private val launchedDriversState = engineFactory.createEngine("launchedDrivers") + private val pendingRetryDriversState = engineFactory.createEngine("retryList") + // Flag to mark if the scheduler is ready to be called, which is until the scheduler + // is registered with Mesos master. + @volatile protected var ready = false + private var masterInfo: Option[MasterInfo] = None + + def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { + val c = new CreateSubmissionResponse + if (!ready) { + c.success = false + c.message = "Scheduler is not ready to take requests" + return c + } + + stateLock.synchronized { + if (isQueueFull()) { + c.success = false + c.message = "Already reached maximum submission size" + return c + } + c.submissionId = desc.submissionId + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + c.success = true + } + c + } + + def killDriver(submissionId: String): KillSubmissionResponse = { + val k = new KillSubmissionResponse + if (!ready) { + k.success = false + k.message = "Scheduler is not ready to take requests" + return k + } + k.submissionId = submissionId + stateLock.synchronized { + // We look for the requested driver in the following places: + // 1. Check if submission is running or launched. + // 2. Check if it's still queued. + // 3. Check if it's in the retry list. + // 4. Check if it has already completed. 
+ if (launchedDrivers.contains(submissionId)) { + val task = launchedDrivers(submissionId) + mesosDriver.killTask(task.taskId) + k.success = true + k.message = "Killing running driver" + } else if (removeFromQueuedDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's still pending" + } else if (removeFromPendingRetryDrivers(submissionId)) { + k.success = true + k.message = "Removed driver while it's being retried" + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + k.success = false + k.message = "Driver already terminated" + } else { + k.success = false + k.message = "Cannot find driver" + } + } + k + } + + def getDriverStatus(submissionId: String): SubmissionStatusResponse = { + val s = new SubmissionStatusResponse + if (!ready) { + s.success = false + s.message = "Scheduler is not ready to take requests" + return s + } + s.submissionId = submissionId + stateLock.synchronized { + if (queuedDrivers.exists(_.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "QUEUED" + } else if (launchedDrivers.contains(submissionId)) { + s.success = true + s.driverState = "RUNNING" + launchedDrivers(submissionId).mesosTaskStatus.foreach(state => s.message = state.toString) + } else if (finishedDrivers.exists(_.driverDescription.submissionId.equals(submissionId))) { + s.success = true + s.driverState = "FINISHED" + finishedDrivers + .find(d => d.driverDescription.submissionId.equals(submissionId)).get.mesosTaskStatus + .foreach(state => s.message = state.toString) + } else if (pendingRetryDrivers.exists(_.submissionId.equals(submissionId))) { + val status = pendingRetryDrivers.find(_.submissionId.equals(submissionId)) + .get.retryState.get.lastFailureStatus + s.success = true + s.driverState = "RETRYING" + s.message = status.toString + } else { + s.success = false + s.driverState = "NOT_FOUND" + } + } + s + } + + private def isQueueFull(): Boolean = launchedDrivers.size >= queuedCapacity + + /** + * Recover scheduler state that is persisted. + * We still need to do task reconciliation to be up to date of the latest task states + * as it might have changed while the scheduler is failing over. + */ + private def recoverState(): Unit = { + stateLock.synchronized { + launchedDriversState.fetchAll[MesosClusterSubmissionState]().foreach { state => + launchedDrivers(state.taskId.getValue) = state + pendingRecover(state.taskId.getValue) = state.slaveId + } + queuedDriversState.fetchAll[MesosDriverDescription]().foreach(d => queuedDrivers += d) + // There is potential timing issue where a queued driver might have been launched + // but the scheduler shuts down before the queued driver was able to be removed + // from the queue. We try to mitigate this issue by walking through all queued drivers + // and remove if they're already launched. + queuedDrivers + .filter(d => launchedDrivers.contains(d.submissionId)) + .foreach(d => removeFromQueuedDrivers(d.submissionId)) + pendingRetryDriversState.fetchAll[MesosDriverDescription]() + .foreach(s => pendingRetryDrivers += s) + // TODO: Consider storing finished drivers so we can show them on the UI after + // failover. For now we clear the history on each recovery. + finishedDrivers.clear() + } + } + + /** + * Starts the cluster scheduler and wait until the scheduler is registered. + * This also marks the scheduler to be ready for requests. + */ + def start(): Unit = { + // TODO: Implement leader election to make sure only one framework running in the cluster. 
+ val fwId = schedulerState.fetch[String]("frameworkId") + val builder = FrameworkInfo.newBuilder() + .setUser(Utils.getCurrentUserName()) + .setName(appName) + .setWebuiUrl(frameworkUrl) + .setCheckpoint(true) + .setFailoverTimeout(Integer.MAX_VALUE) // Setting to max so tasks keep running on crash + fwId.foreach { id => + builder.setId(FrameworkID.newBuilder().setValue(id).build()) + frameworkId = id + } + recoverState() + metricsSystem.registerSource(new MesosClusterSchedulerSource(this)) + metricsSystem.start() + startScheduler(master, MesosClusterScheduler.this, builder.build()) + ready = true + } + + def stop(): Unit = { + ready = false + metricsSystem.report() + metricsSystem.stop() + mesosDriver.stop(true) + } + + override def registered( + driver: SchedulerDriver, + newFrameworkId: FrameworkID, + masterInfo: MasterInfo): Unit = { + logInfo("Registered as framework ID " + newFrameworkId.getValue) + if (newFrameworkId.getValue != frameworkId) { + frameworkId = newFrameworkId.getValue + schedulerState.persist("frameworkId", frameworkId) + } + markRegistered() + + stateLock.synchronized { + this.masterInfo = Some(masterInfo) + if (!pendingRecover.isEmpty) { + // Start task reconciliation if we need to recover. + val statuses = pendingRecover.collect { + case (taskId, slaveId) => + val newStatus = TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId).build()) + .setSlaveId(slaveId) + .setState(MesosTaskState.TASK_STAGING) + .build() + launchedDrivers.get(taskId).map(_.mesosTaskStatus.getOrElse(newStatus)) + .getOrElse(newStatus) + } + // TODO: Page the status updates to avoid trying to reconcile + // a large amount of tasks at once. + driver.reconcileTasks(statuses) + } + } + } + + private def buildDriverCommand(desc: MesosDriverDescription): CommandInfo = { + val appJar = CommandInfo.URI.newBuilder() + .setValue(desc.jarUrl.stripPrefix("file:").stripPrefix("local:")).build() + val builder = CommandInfo.newBuilder().addUris(appJar) + val entries = + (conf.getOption("spark.executor.extraLibraryPath").toList ++ + desc.command.libraryPathEntries) + val prefixEnv = if (!entries.isEmpty) { + Utils.libraryPathEnvPrefix(entries) + } else { + "" + } + val envBuilder = Environment.newBuilder() + desc.command.environment.foreach { case (k, v) => + envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v).build()) + } + // Pass all spark properties to executor. 
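An end-to-end usage sketch of the scheduler's public surface shown so far, reusing the hypothetical engineFactory and conf from the earlier sketch; the dispatcher normally does this wiring, conf must already carry spark.master and spark.app.name, and the MesosDriverDescription construction is not shown in this hunk:

    import org.apache.spark.deploy.mesos.MesosDriverDescription

    val scheduler = new MesosClusterScheduler(engineFactory, conf)
    scheduler.frameworkUrl = "http://dispatcher-host:8081"   // hypothetical dispatcher UI URL
    scheduler.start()                                        // register, recover persisted state, mark ready

    val desc: MesosDriverDescription = ???                   // construction not shown in this hunk
    val created = scheduler.submitDriver(desc)
    if (created.success) {
      val status = scheduler.getDriverStatus(created.submissionId)
      // status.driverState is one of QUEUED, RUNNING, RETRYING, FINISHED, or NOT_FOUND.
      scheduler.killDriver(created.submissionId)
    }
    scheduler.stop()
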
+ val executorOpts = desc.schedulerProperties.map { case (k, v) => s"-D$k=$v" }.mkString(" ") + envBuilder.addVariables( + Variable.newBuilder().setName("SPARK_EXECUTOR_OPTS").setValue(executorOpts)) + val cmdOptions = generateCmdOption(desc) + val executorUri = desc.schedulerProperties.get("spark.executor.uri") + .orElse(desc.command.environment.get("SPARK_EXECUTOR_URI")) + val appArguments = desc.command.arguments.mkString(" ") + val cmd = if (executorUri.isDefined) { + builder.addUris(CommandInfo.URI.newBuilder().setValue(executorUri.get).build()) + val folderBasename = executorUri.get.split('/').last.split('.').head + val cmdExecutable = s"cd $folderBasename*; $prefixEnv bin/spark-submit" + val cmdJar = s"../${desc.jarUrl.split("/").last}" + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } else { + val executorSparkHome = desc.schedulerProperties.get("spark.mesos.executor.home") + .orElse(conf.getOption("spark.home")) + .orElse(Option(System.getenv("SPARK_HOME"))) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } + val cmdExecutable = new File(executorSparkHome, "./bin/spark-submit").getCanonicalPath + val cmdJar = desc.jarUrl.split("/").last + s"$cmdExecutable ${cmdOptions.mkString(" ")} $cmdJar $appArguments" + } + builder.setValue(cmd) + builder.setEnvironment(envBuilder.build()) + builder.build() + } + + private def generateCmdOption(desc: MesosDriverDescription): Seq[String] = { + var options = Seq( + "--name", desc.schedulerProperties("spark.app.name"), + "--class", desc.command.mainClass, + "--master", s"mesos://${conf.get("spark.master")}", + "--driver-cores", desc.cores.toString, + "--driver-memory", s"${desc.mem}M") + desc.schedulerProperties.get("spark.executor.memory").map { v => + options ++= Seq("--executor-memory", v) + } + desc.schedulerProperties.get("spark.cores.max").map { v => + options ++= Seq("--total-executor-cores", v) + } + options + } + + private class ResourceOffer(val offer: Offer, var cpu: Double, var mem: Double) { + override def toString(): String = { + s"Offer id: ${offer.getId.getValue}, cpu: $cpu, mem: $mem" + } + } + + /** + * This method takes all the possible candidates and attempt to schedule them with Mesos offers. + * Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled + * logic on each task. 
+ */ + private def scheduleTasks( + candidates: Seq[MesosDriverDescription], + afterLaunchCallback: (String) => Boolean, + currentOffers: List[ResourceOffer], + tasks: mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]): Unit = { + for (submission <- candidates) { + val driverCpu = submission.cores + val driverMem = submission.mem + logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") + val offerOption = currentOffers.find { o => + o.cpu >= driverCpu && o.mem >= driverMem + } + if (offerOption.isEmpty) { + logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + + s"cpu: $driverCpu, mem: $driverMem") + } else { + val offer = offerOption.get + offer.cpu -= driverCpu + offer.mem -= driverMem + val taskId = TaskID.newBuilder().setValue(submission.submissionId).build() + val cpuResource = Resource.newBuilder() + .setName("cpus").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverCpu)).build() + val memResource = Resource.newBuilder() + .setName("mem").setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(driverMem)).build() + val commandInfo = buildDriverCommand(submission) + val appName = submission.schedulerProperties("spark.app.name") + val taskInfo = TaskInfo.newBuilder() + .setTaskId(taskId) + .setName(s"Driver for $appName") + .setSlaveId(offer.offer.getSlaveId) + .setCommand(commandInfo) + .addResources(cpuResource) + .addResources(memResource) + .build() + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) + queuedTasks += taskInfo + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + + submission.submissionId) + val newState = new MesosClusterSubmissionState(submission, taskId, offer.offer.getSlaveId, + None, new Date()) + launchedDrivers(submission.submissionId) = newState + launchedDriversState.persist(submission.submissionId, newState) + afterLaunchCallback(submission.submissionId) + } + } + } + + override def resourceOffers(driver: SchedulerDriver, offers: JList[Offer]): Unit = { + val currentOffers = offers.map { o => + new ResourceOffer( + o, getResource(o.getResourcesList, "cpus"), getResource(o.getResourcesList, "mem")) + }.toList + logTrace(s"Received offers from Mesos: \n${currentOffers.mkString("\n")}") + val tasks = new mutable.HashMap[OfferID, ArrayBuffer[TaskInfo]]() + val currentTime = new Date() + + stateLock.synchronized { + // We first schedule all the supervised drivers that are ready to retry. + // This list will be empty if none of the drivers are marked as supervise. + val driversToRetry = pendingRetryDrivers.filter { d => + d.retryState.get.nextRetry.before(currentTime) + } + scheduleTasks( + driversToRetry, + removeFromPendingRetryDrivers, + currentOffers, + tasks) + // Then we walk through the queued drivers and try to schedule them. 
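Editor's note: scheduleTasks above uses a simple first-fit policy: for each candidate driver it takes the first remaining offer with enough CPU and memory and decrements that offer's capacity before moving on. The self-contained sketch below shows the same policy with plain case classes instead of the Mesos protobuf types; all names are illustrative.

    // Sketch of the first-fit matching used by scheduleTasks; the types are stand-ins.
    object FirstFitSketch {
      final case class Driver(id: String, cpus: Double, memMb: Double)
      final class Offer(val id: String, var cpus: Double, var memMb: Double)

      /** Returns driverId -> offerId for every driver that could be placed. */
      def place(drivers: Seq[Driver], offers: Seq[Offer]): Map[String, String] = {
        val placements = scala.collection.mutable.Map.empty[String, String]
        for (d <- drivers) {
          offers.find(o => o.cpus >= d.cpus && o.memMb >= d.memMb).foreach { o =>
            o.cpus -= d.cpus   // reserve the resources on the chosen offer
            o.memMb -= d.memMb
            placements(d.id) = o.id
          }
        }
        placements.toMap
      }
    }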
+ scheduleTasks( + queuedDrivers, + removeFromQueuedDrivers, + currentOffers, + tasks) + } + tasks.foreach { case (offerId, tasks) => + driver.launchTasks(Collections.singleton(offerId), tasks) + } + offers + .filter(o => !tasks.keySet.contains(o.getId)) + .foreach(o => driver.declineOffer(o.getId)) + } + + def getSchedulerState(): MesosClusterSchedulerState = { + def copyBuffer( + buffer: ArrayBuffer[MesosDriverDescription]): ArrayBuffer[MesosDriverDescription] = { + val newBuffer = new ArrayBuffer[MesosDriverDescription](buffer.size) + buffer.copyToBuffer(newBuffer) + newBuffer + } + stateLock.synchronized { + new MesosClusterSchedulerState( + frameworkId, + masterInfo.map(m => s"http://${m.getIp}:${m.getPort}"), + copyBuffer(queuedDrivers), + launchedDrivers.values.map(_.copy()).toList, + finishedDrivers.map(_.copy()).toList, + copyBuffer(pendingRetryDrivers)) + } + } + + override def offerRescinded(driver: SchedulerDriver, offerId: OfferID): Unit = {} + override def disconnected(driver: SchedulerDriver): Unit = {} + override def reregistered(driver: SchedulerDriver, masterInfo: MasterInfo): Unit = { + logInfo(s"Framework re-registered with master ${masterInfo.getId}") + } + override def slaveLost(driver: SchedulerDriver, slaveId: SlaveID): Unit = {} + override def error(driver: SchedulerDriver, error: String): Unit = { + logError("Error received: " + error) + } + + /** + * Check if the task state is a recoverable state that we can relaunch the task. + * Task state like TASK_ERROR are not relaunchable state since it wasn't able + * to be validated by Mesos. + */ + private def shouldRelaunch(state: MesosTaskState): Boolean = { + state == MesosTaskState.TASK_FAILED || + state == MesosTaskState.TASK_KILLED || + state == MesosTaskState.TASK_LOST + } + + override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { + val taskId = status.getTaskId.getValue + stateLock.synchronized { + if (launchedDrivers.contains(taskId)) { + if (status.getReason == Reason.REASON_RECONCILIATION && + !pendingRecover.contains(taskId)) { + // Task has already received update and no longer requires reconciliation. + return + } + val state = launchedDrivers(taskId) + // Check if the driver is supervise enabled and can be relaunched. 
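Editor's note: the relaunch branch that follows doubles the wait time on every failure, caps it at maxRetryWaitTime, and schedules the next attempt that many seconds in the future. A small sketch of just that backoff calculation is below; the cap is passed in explicitly here because its value is configured elsewhere in the scheduler, and the names are illustrative.

    import java.util.Date

    // Sketch of the supervised-driver retry backoff: 1s, 2s, 4s, ... capped at maxWaitSec.
    object RetryBackoffSketch {
      /** Given the previous (retries, waitSec), returns (retries, waitSec, nextRetryTime). */
      def next(prev: Option[(Int, Int)], maxWaitSec: Int, now: Date = new Date()): (Int, Int, Date) = {
        val (retries, waitSec) = prev
          .map { case (r, w) => (r + 1, math.min(maxWaitSec, w * 2)) }
          .getOrElse((1, 1))
        (retries, waitSec, new Date(now.getTime + waitSec * 1000L))
      }
    }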
+ if (state.driverDescription.supervise && shouldRelaunch(status.getState)) { + removeFromLaunchedDrivers(taskId) + val retryState: Option[MesosClusterRetryState] = state.driverDescription.retryState + val (retries, waitTimeSec) = retryState + .map { rs => (rs.retries + 1, Math.min(maxRetryWaitTime, rs.waitTime * 2)) } + .getOrElse{ (1, 1) } + val nextRetry = new Date(new Date().getTime + waitTimeSec * 1000L) + + val newDriverDescription = state.driverDescription.copy( + retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) + pendingRetryDrivers += newDriverDescription + pendingRetryDriversState.persist(taskId, newDriverDescription) + } else if (TaskState.isFinished(TaskState.fromMesos(status.getState))) { + removeFromLaunchedDrivers(taskId) + if (finishedDrivers.size >= retainedDrivers) { + val toRemove = math.max(retainedDrivers / 10, 1) + finishedDrivers.trimStart(toRemove) + } + finishedDrivers += state + } + state.mesosTaskStatus = Option(status) + } else { + logError(s"Unable to find driver $taskId in status update") + } + } + } + + override def frameworkMessage( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + message: Array[Byte]): Unit = {} + + override def executorLost( + driver: SchedulerDriver, + executorId: ExecutorID, + slaveId: SlaveID, + status: Int): Unit = {} + + private def removeFromQueuedDrivers(id: String): Boolean = { + val index = queuedDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + queuedDrivers.remove(index) + queuedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromLaunchedDrivers(id: String): Boolean = { + if (launchedDrivers.remove(id).isDefined) { + launchedDriversState.expunge(id) + true + } else { + false + } + } + + private def removeFromPendingRetryDrivers(id: String): Boolean = { + val index = pendingRetryDrivers.indexWhere(_.submissionId.equals(id)) + if (index != -1) { + pendingRetryDrivers.remove(index) + pendingRetryDriversState.expunge(id) + true + } else { + false + } + } + + def getQueuedDriversSize: Int = queuedDrivers.size + def getLaunchedDriversSize: Int = launchedDrivers.size + def getPendingRetryDriversSize: Int = pendingRetryDrivers.size +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala new file mode 100644 index 0000000000000..1fe94974c8e36 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster.mesos + +import com.codahale.metrics.{Gauge, MetricRegistry} + +import org.apache.spark.metrics.source.Source + +private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) + extends Source { + override def sourceName: String = "mesos_cluster" + override def metricRegistry: MetricRegistry = new MetricRegistry() + + metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getQueuedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("launchedDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getLaunchedDriversSize + }) + + metricRegistry.register(MetricRegistry.name("retryDrivers"), new Gauge[Int] { + override def getValue: Int = scheduler.getPendingRetryDriversSize + }) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index d9d62b0e287ed..8346a2407489f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -18,23 +18,19 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.{ArrayList => JArrayList, List => JList} -import java.util.Collections +import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.mesos.protobuf.ByteString -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, - ExecutorInfo => MesosExecutorInfo, _} - +import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.{Logging, SparkContext, SparkException, TaskState} -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.Utils +import org.apache.spark.{SparkContext, SparkException, TaskState} /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a @@ -47,14 +43,7 @@ private[spark] class MesosSchedulerBackend( master: String) extends SchedulerBackend with MScheduler - with Logging { - - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() - - // Driver for talking to Mesos - var driver: SchedulerDriver = null + with MesosSchedulerUtils { // Which slave IDs we have executors on val slaveIdsWithExecutors = new HashSet[String] @@ -73,26 +62,9 @@ private[spark] class MesosSchedulerBackend( @volatile var appId: String = _ override def start() { - synchronized { - classLoader = Thread.currentThread.getContextClassLoader - - new Thread("MesosSchedulerBackend driver") { - setDaemon(true) - override def run() { - val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() - driver = new MesosSchedulerDriver(scheduler, fwInfo, master) - try { - val ret = driver.run() - logInfo("driver.run() returned with code " + ret) - } catch { - case e: Exception => logError("driver.run() failed", e) - } - } - }.start() - - waitForRegister() - } + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() + classLoader = Thread.currentThread.getContextClassLoader + startScheduler(master, MesosSchedulerBackend.this, fwInfo) } def createExecutorInfo(execId: String): MesosExecutorInfo = { @@ -125,17 +97,19 @@ private[spark] class MesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val uri = sc.conf.get("spark.executor.uri", null) + val uri = sc.conf.getOption("spark.executor.uri") + .orElse(Option(System.getenv("SPARK_EXECUTOR_URI"))) + val executorBackendName = classOf[MesosExecutorBackend].getName - if (uri == null) { + if (uri.isEmpty) { val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath command.setValue(s"$prefixEnv $executorPath $executorBackendName") } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
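Editor's note: the "glob the directory" comment above works by taking everything before the first '.' in the archive file name and letting the shell expand the rest with '*'. The sketch below demonstrates the string manipulation on a hypothetical executor URI.

    // Sketch: derive the cd target used before launching the executor from an archive URI.
    object ExecutorUriSketch {
      def globPrefix(uri: String): String = uri.split('/').last.split('.').head

      def main(args: Array[String]): Unit = {
        val uri = "hdfs://nn:8020/dist/spark-1.3.1-bin-hadoop2.4.tgz"  // hypothetical
        // Prints "spark-1"; the generated command is then "cd spark-1*; ./bin/spark-class ...",
        // so the shell glob matches whatever directory the archive extracts to.
        println(globPrefix(uri))
      }
    }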
- val basename = uri.split('/').last.split('.').head + val basename = uri.get.split('/').last.split('.').head command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName") - command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) + command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get)) } val cpus = Resource.newBuilder() .setName("cpus") @@ -181,18 +155,7 @@ private[spark] class MesosSchedulerBackend( inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) - registeredLock.synchronized { - isRegistered = true - registeredLock.notifyAll() - } - } - } - - def waitForRegister() { - registeredLock.synchronized { - while (!isRegistered) { - registeredLock.wait() - } + markRegistered() } } @@ -287,14 +250,6 @@ private[spark] class MesosSchedulerBackend( } } - /** Helper function to pull out a resource from a Mesos Resources protobuf */ - def getResource(res: JList[Resource], name: String): Double = { - for (r <- res if r.getName == name) { - return r.getScalar.getValue - } - 0 - } - /** Turn a Spark TaskDescription into a Mesos task */ def createMesosTask(task: TaskDescription, slaveId: String): MesosTaskInfo = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() @@ -339,13 +294,13 @@ private[spark] class MesosSchedulerBackend( } override def stop() { - if (driver != null) { - driver.stop() + if (mesosDriver != null) { + mesosDriver.stop() } } override def reviveOffers() { - driver.reviveOffers() + mesosDriver.reviveOffers() } override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} @@ -380,7 +335,7 @@ private[spark] class MesosSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - driver.killTask( + mesosDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala new file mode 100644 index 0000000000000..d11228f3d016a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster.mesos + +import java.util.List +import java.util.concurrent.CountDownLatch + +import scala.collection.JavaConversions._ + +import org.apache.mesos.Protos.{FrameworkInfo, Resource, Status} +import org.apache.mesos.{MesosSchedulerDriver, Scheduler} +import org.apache.spark.Logging +import org.apache.spark.util.Utils + +/** + * Shared trait for implementing a Mesos Scheduler. 
This holds common state and helper + * methods and Mesos scheduler will use. + */ +private[mesos] trait MesosSchedulerUtils extends Logging { + // Lock used to wait for scheduler to be registered + private final val registerLatch = new CountDownLatch(1) + + // Driver for talking to Mesos + protected var mesosDriver: MesosSchedulerDriver = null + + /** + * Starts the MesosSchedulerDriver with the provided information. This method returns + * only after the scheduler has registered with Mesos. + * @param masterUrl Mesos master connection URL + * @param scheduler Scheduler object + * @param fwInfo FrameworkInfo to pass to the Mesos master + */ + def startScheduler(masterUrl: String, scheduler: Scheduler, fwInfo: FrameworkInfo): Unit = { + synchronized { + if (mesosDriver != null) { + registerLatch.await() + return + } + + new Thread(Utils.getFormattedClassName(this) + "-mesos-driver") { + setDaemon(true) + + override def run() { + mesosDriver = new MesosSchedulerDriver(scheduler, fwInfo, masterUrl) + try { + val ret = mesosDriver.run() + logInfo("driver.run() returned with code " + ret) + if (ret.equals(Status.DRIVER_ABORTED)) { + System.exit(1) + } + } catch { + case e: Exception => { + logError("driver.run() failed", e) + System.exit(1) + } + } + } + }.start() + + registerLatch.await() + } + } + + /** + * Signal that the scheduler has registered with Mesos. + */ + protected def markRegistered(): Unit = { + registerLatch.countDown() + } + + /** + * Get the amount of resources for the specified type from the resource list + */ + protected def getResource(res: List[Resource], name: String): Double = { + for (r <- res if r.getName == name) { + return r.getScalar.getValue + } + 0.0 + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index ac5b524517818..e64d06c4d3cfc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -123,7 +123,7 @@ private[spark] class LocalBackend( } override def stop() { - localEndpoint.sendWithReply(StopExecutor) + localEndpoint.ask(StopExecutor) } override def reviveOffers() { diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 579fb6624e692..754832b8a4ca7 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -49,16 +49,17 @@ class KryoSerializer(conf: SparkConf) with Logging with Serializable { - private val bufferSizeMb = conf.getDouble("spark.kryoserializer.buffer.mb", 0.064) - if (bufferSizeMb >= 2048) { - throw new IllegalArgumentException("spark.kryoserializer.buffer.mb must be less than " + - s"2048 mb, got: + $bufferSizeMb mb.") + private val bufferSizeKb = conf.getSizeAsKb("spark.kryoserializer.buffer", "64k") + + if (bufferSizeKb >= 2048) { + throw new IllegalArgumentException("spark.kryoserializer.buffer must be less than " + + s"2048 mb, got: + $bufferSizeKb mb.") } - private val bufferSize = (bufferSizeMb * 1024 * 1024).toInt + private val bufferSize = (bufferSizeKb * 1024).toInt - val maxBufferSizeMb = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) + val maxBufferSizeMb = conf.getSizeAsMb("spark.kryoserializer.buffer.max", "64m").toInt if (maxBufferSizeMb >= 2048) { - throw new 
IllegalArgumentException("spark.kryoserializer.buffer.max.mb must be less than " + + throw new IllegalArgumentException("spark.kryoserializer.buffer.max must be less than " + s"2048 mb, got: + $maxBufferSizeMb mb.") } private val maxBufferSize = maxBufferSizeMb * 1024 * 1024 @@ -173,7 +174,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - "increase spark.kryoserializer.buffer.max.mb value.") + "increase spark.kryoserializer.buffer.max value.") } ByteBuffer.wrap(output.toBytes) } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala index cecb992579655..5abfa467c0ec8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -23,6 +23,7 @@ import java.security.AccessController import scala.annotation.tailrec import scala.collection.mutable +import scala.util.control.NonFatal import org.apache.spark.Logging @@ -35,8 +36,15 @@ private[serializer] object SerializationDebugger extends Logging { */ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { if (enableDebugging && reflect != null) { - new NotSerializableException( - e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + try { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } catch { + case NonFatal(t) => + // Fall back to old exception + logWarning("Exception in serialization debugger", t) + e + } } else { e } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 538e150ead05a..e9b4e2b955dc8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -78,7 +78,8 @@ class FileShuffleBlockManager(conf: SparkConf) private val consolidateShuffleFiles = conf.getBoolean("spark.shuffle.consolidateFiles", false) - private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 /** * Contains all the state related to a particular shuffle. 
This includes a pool of unused diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 7a2c5ae32d98b..80374adc44296 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -79,7 +79,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { blockManager, blocksByAddress, serializer, - SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index c798843bd5d8a..9bfc4201d37c0 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -55,7 +55,7 @@ class BlockManagerMaster( memSize: Long, diskSize: Long, tachyonSize: Long): Boolean = { - val res = driverEndpoint.askWithReply[Boolean]( + val res = driverEndpoint.askWithRetry[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logDebug(s"Updated info of block $blockId") res @@ -63,12 +63,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { - driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) + driverEndpoint.askWithRetry[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } /** @@ -81,11 +81,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) } def getRpcHostPortForExecutor(executorId: String): Option[(String, Int)] = { - driverEndpoint.askWithReply[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) + driverEndpoint.askWithRetry[Option[(String, Int)]](GetRpcHostPortForExecutor(executorId)) } /** @@ -93,12 +93,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId)) + driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") @@ -110,7 +110,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. 
*/ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") @@ -122,7 +122,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = driverEndpoint.askWithReply[Future[Seq[Int]]]( + val future = driverEndpoint.askWithRetry[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -141,11 +141,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus) } /** @@ -166,7 +166,7 @@ class BlockManagerMaster( * master endpoint for a response to a prior message. */ val response = driverEndpoint. - askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) if (result == null) { @@ -190,7 +190,7 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) Await.result(future, timeout) } @@ -205,7 +205,7 @@ class BlockManagerMaster( /** Send a one-way message to the master endpoint, to which we expect it to reply with true. 
*/ private def tell(message: Any) { - if (!driverEndpoint.askWithReply[Boolean](message)) { + if (!driverEndpoint.askWithRetry[Boolean](message)) { throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 4682167912ff0..7212362df5d71 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -132,7 +132,7 @@ class BlockManagerMasterEndpoint( val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveEndpoint.sendWithReply[Int](removeMsg) + bm.slaveEndpoint.ask[Int](removeMsg) }.toSeq ) } @@ -142,7 +142,7 @@ class BlockManagerMasterEndpoint( val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveEndpoint.sendWithReply[Boolean](removeMsg) + bm.slaveEndpoint.ask[Boolean](removeMsg) }.toSeq ) } @@ -159,7 +159,7 @@ class BlockManagerMasterEndpoint( } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveEndpoint.sendWithReply[Int](removeMsg) + bm.slaveEndpoint.ask[Int](removeMsg) }.toSeq ) } @@ -214,7 +214,7 @@ class BlockManagerMasterEndpoint( // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. - blockManager.get.slaveEndpoint.sendWithReply[Boolean](RemoveBlock(blockId)) + blockManager.get.slaveEndpoint.ask[Boolean](RemoveBlock(blockId)) } } } @@ -253,7 +253,7 @@ class BlockManagerMasterEndpoint( blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveEndpoint.sendWithReply[Option[BlockStatus]](getBlockStatus) + info.slaveEndpoint.ask[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -277,7 +277,7 @@ class BlockManagerMasterEndpoint( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveEndpoint.sendWithReply[Seq[BlockId]](getMatchingBlockIds) + info.slaveEndpoint.ask[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 4b232ae7d3180..1f45956282166 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -31,8 +31,7 @@ import org.apache.spark.util.Utils private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - val minMemoryMapBytes = blockManager.conf.getLong( - "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L) + val minMemoryMapBytes = blockManager.conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") override def getSize(blockId: BlockId): Long = { diskManager.getFile(blockId.name).length diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 26ffbf9350388..4dd7ab9e0767b 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -179,7 +179,7 @@ private[spark] object SizeEstimator extends Logging { } // 
Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling. - private val ARRAY_SIZE_FOR_SAMPLING = 200 + private val ARRAY_SIZE_FOR_SAMPLING = 400 private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) { @@ -204,25 +204,40 @@ private[spark] object SizeEstimator extends Logging { } } else { // Estimate the size of a large array by sampling elements without replacement. - var size = 0.0 + // To exclude the shared objects that the array elements may link, sample twice + // and use the min one to caculate array size. val rand = new Random(42) - val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE) - var numElementsDrawn = 0 - while (numElementsDrawn < ARRAY_SAMPLE_SIZE) { - var index = 0 - do { - index = rand.nextInt(length) - } while (drawn.contains(index)) - drawn.add(index) - val elem = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] - size += SizeEstimator.estimate(elem, state.visited) - numElementsDrawn += 1 - } - state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong + val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE) + val s1 = sampleArray(array, state, rand, drawn, length) + val s2 = sampleArray(array, state, rand, drawn, length) + val size = math.min(s1, s2) + state.size += math.max(s1, s2) + + (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong } } } + private def sampleArray( + array: AnyRef, + state: SearchState, + rand: Random, + drawn: OpenHashSet[Int], + length: Int): Long = { + var size = 0L + for (i <- 0 until ARRAY_SAMPLE_SIZE) { + var index = 0 + do { + index = rand.nextInt(length) + } while (drawn.contains(index)) + drawn.add(index) + val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef] + if (obj != null) { + size += SizeEstimator.estimate(obj, state.visited).toLong + } + } + size + } + private def primitiveSize(cls: Class[_]): Long = { if (cls == classOf[Byte]) { BYTE_SIZE diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c6c6df7cfa56e..4b5a5df5ef7b7 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -67,6 +67,12 @@ private[spark] object Utils extends Logging { val DEFAULT_SHUTDOWN_PRIORITY = 100 + /** + * The shutdown priority of the SparkContext instance. This is lower than the default + * priority, so that by default hooks are run before the context is shut down. + */ + val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50 + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null @@ -1014,21 +1020,48 @@ private[spark] object Utils extends Logging { } /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + def byteStringAsBytes(str: String): Long = { + JavaUtils.byteStringAsBytes(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + def byteStringAsKb(str: String): Long = { + JavaUtils.byteStringAsKb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for internal use. 
+ * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + def byteStringAsMb(str: String): Long = { + JavaUtils.byteStringAsMb(str) + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m, 500g) to gibibytes for internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. + */ + def byteStringAsGb(str: String): Long = { + JavaUtils.byteStringAsGb(str) + } + + /** + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of mebibytes. */ def memoryStringToMb(str: String): Int = { - val lower = str.toLowerCase - if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt - } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt - } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 - } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes - (lower.toLong / 1024 / 1024).toInt - } + // Convert to bytes, rather than directly to MB, because when no units are specified the unit + // is assumed to be bytes + (JavaUtils.byteStringAsBytes(str) / 1024 / 1024).toInt } /** @@ -1266,16 +1299,18 @@ private[spark] object Utils extends Logging { } /** Default filtering function for finding call sites using `getCallSite`. */ - private def coreExclusionFunction(className: String): Boolean = { - // A regular expression to match classes of the "core" Spark API that we want to skip when - // finding the call site of a method. + private def sparkInternalExclusionFunction(className: String): Boolean = { + // A regular expression to match classes of the internal Spark API's + // that we want to skip when finding the call site of a method. val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r + val SPARK_SQL_CLASS_REGEX = """^org\.apache\.spark\.sql.*""".r val SCALA_CORE_CLASS_PREFIX = "scala" - val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined + val isSparkClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined || + SPARK_SQL_CLASS_REGEX.findFirstIn(className).isDefined val isScalaClass = className.startsWith(SCALA_CORE_CLASS_PREFIX) // If the class is a Spark internal class or a Scala class, then exclude. - isSparkCoreClass || isScalaClass + isSparkClass || isScalaClass } /** @@ -1285,7 +1320,7 @@ private[spark] object Utils extends Logging { * * @param skipClass Function that is used to exclude non-user-code classes. */ - def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { + def getCallSite(skipClass: String => Boolean = sparkInternalExclusionFunction): CallSite = { // Keep crawling up the stack trace until we find the first function not inside of the spark // package. We track the last (shallowest) contiguous Spark method. This might be an RDD // transformation, a SparkContext function (such as parallelize), or anything else that leads @@ -1324,9 +1359,17 @@ private[spark] object Utils extends Logging { } val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt - CallSite( - shortForm = s"$lastSparkMethod at $firstUserFile:$firstUserLine", - longForm = callStack.take(callStackDepth).mkString("\n")) + val shortForm = + if (firstUserFile == "HiveSessionImpl.java") { + // To be more user friendly, show a nicer string for queries submitted from the JDBC + // server. 
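Editor's note, returning briefly to the widened exclusion function above: it now skips both Spark core and Spark SQL classes when walking the stack for a call site. The standalone check below copies the two regexes from this change and exercises them against example class names (the class names themselves are just examples).

    // Sketch: which stack-frame classes the call-site search now skips.
    object CallSiteExclusionSketch {
      private val sparkCore =
        """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r
      private val sparkSql = """^org\.apache\.spark\.sql.*""".r

      def excluded(className: String): Boolean =
        sparkCore.findFirstIn(className).isDefined ||
          sparkSql.findFirstIn(className).isDefined ||
          className.startsWith("scala")

      def main(args: Array[String]): Unit = {
        println(excluded("org.apache.spark.rdd.RDD"))        // true: core API, skipped
        println(excluded("org.apache.spark.sql.DataFrame"))  // true: now skipped as well
        println(excluded("com.example.MyJob"))               // false: treated as user code
      }
    }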
+ "Spark JDBC Server Query" + } else { + s"$lastSparkMethod at $firstUserFile:$firstUserLine" + } + val longForm = callStack.take(callStackDepth).mkString("\n") + + CallSite(shortForm, longForm) } /** Return a string containing part of a file from byte 'start' to 'end'. */ @@ -2116,7 +2159,7 @@ private[spark] object Utils extends Logging { * @return A handle that can be used to unregister the shutdown hook. */ def addShutdownHook(hook: () => Unit): AnyRef = { - addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook) + addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook) } /** @@ -2126,7 +2169,7 @@ private[spark] object Utils extends Logging { * @param hook The code to run during shutdown. * @return A handle that can be used to unregister the shutdown hook. */ - def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = { + def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = { shutdownHooks.add(priority, hook) } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 30dd7f22e494f..f912049563906 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -89,8 +89,10 @@ class ExternalAppendOnlyMap[K, V, C]( // Number of bytes spilled in total private var _diskBytesSpilled = 0L - - private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = + sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 79a695fb62086..ef3cac622505e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -108,7 +108,9 @@ private[spark] class ExternalSorter[K, V, C]( private val conf = SparkEnv.get.conf private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) - private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 + + // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided + private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // Size of object batches when reading/writing from serializers. 
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index e579421676343..7138b4b8e4533 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -138,7 +138,7 @@ private[spark] object RollingFileAppender { val STRATEGY_DEFAULT = "" val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval" val INTERVAL_DEFAULT = "daily" - val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes" + val SIZE_PROPERTY = "spark.executor.logs.rolling.maxSize" val SIZE_DEFAULT = (1024 * 1024).toString val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles" val DEFAULT_BUFFER_SIZE = 8192 diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 8a4f2a08fe701..34ac9361d46c6 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1009,7 +1009,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 70529d9216591..668ddf9f5f0a9 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // Local computation should not persist the resulting value, so don't expect a put(). 
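Editor's note on the property rename a few hunks up: executor log rolling is now configured through spark.executor.logs.rolling.maxSize rather than the old size.maxBytes key. A minimal sketch of setting the rolling keys that appear in RollingFileAppender follows; the retained-files count is an arbitrary example, and per SIZE_DEFAULT the size value shown is in plain bytes.

    import org.apache.spark.SparkConf

    // Sketch: the executor log rolling keys referenced in RollingFileAppender after the rename.
    object LogRollingConfSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(false)
          .set("spark.executor.logs.rolling.maxSize", (10 * 1024 * 1024).toString) // bytes
          .set("spark.executor.logs.rolling.maxRetainedFiles", "5")                 // example
          .set("spark.executor.logs.rolling.time.interval", "daily")
        println(conf.get("spark.executor.logs.rolling.maxSize"))
      }
    }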
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 97ea3578aa8ba..96a9c207ad022 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -77,7 +77,7 @@ class DistributedSuite extends FunSuite with Matchers with LocalSparkContext { } test("groupByKey where map output sizes exceed maxMbInFlight") { - val conf = new SparkConf().set("spark.reducer.maxMbInFlight", "1") + val conf = new SparkConf().set("spark.reducer.maxSizeInFlight", "1m") sc = new SparkContext(clusterUrl, "test", conf) // This data should be around 20 MB, so even with 4 mappers and 2 reducers, each map output // file should be about 2.5 MB diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index a69e9b761f9a7..c0439f934813e 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -22,8 +22,7 @@ import java.net.URI import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl.SSLException -import com.google.common.io.ByteStreams -import org.apache.commons.io.{FileUtils, IOUtils} +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.lang3.RandomUtils import org.scalatest.FunSuite @@ -239,7 +238,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = { val randomContent = RandomUtils.nextBytes(100) val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir) - FileUtils.writeByteArrayToFile(file, randomContent) + Files.write(randomContent, file) server.addFile(file) val uri = new URI(server.serverUri + "/files/" + file.getName) @@ -254,7 +253,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { Utils.setupSecureURLConnection(connection, sm) } - val buf = IOUtils.toByteArray(connection.getInputStream) + val buf = ByteStreams.toByteArray(connection.getInputStream) assert(buf === randomContent) } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 0fd570e5297d9..b789912e9ebef 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -48,7 +48,7 @@ class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { val metrics = new TaskMetrics val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithReply[HeartbeatResponse]( + val response = receiverRef.askWithRetry[HeartbeatResponse]( Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) 
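Editor's note: the test updates above and below track the RPC rename in this patch: sendWithReply becomes ask (send a request, get a Future back) and askWithReply becomes askWithRetry (block for a typed reply, retrying on failure). The generic sketch below only illustrates that blocking-with-retries semantics under assumed attempt and timeout values; it is not Spark's RpcEndpointRef implementation.

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.util.{Failure, Success, Try}

    // Sketch: the semantics suggested by the askWithRetry name -- block on an ask-style
    // Future and retry a bounded number of times before giving up.
    object AskWithRetrySketch {
      def askWithRetry[T](ask: () => Future[T], attempts: Int = 3, timeout: Duration = 30.seconds): T = {
        var lastError: Throwable = null
        for (_ <- 1 to attempts) {
          Try(Await.result(ask(), timeout)) match {
            case Success(value) => return value
            case Failure(e) => lastError = e
          }
        }
        throw new RuntimeException(s"Ask failed after $attempts attempts", lastError)
      }
    }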
verify(scheduler).executorHeartbeatReceived( @@ -71,7 +71,7 @@ class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { val metrics = new TaskMetrics val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) - val response = receiverRef.askWithReply[HeartbeatResponse]( + val response = receiverRef.askWithRetry[HeartbeatResponse]( Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) verify(scheduler).executorHeartbeatReceived( diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 272e6af0514e4..68d08e32f9aa4 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -24,11 +24,30 @@ import scala.language.postfixOps import scala.util.{Try, Random} import org.scalatest.FunSuite +import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} import org.apache.spark.util.{RpcUtils, ResetSystemProperties} import com.esotericsoftware.kryo.Kryo class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemProperties { + test("Test byteString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getSizeAsBytes("fake","1k") === ByteUnit.KiB.toBytes(1)) + assert(conf.getSizeAsKb("fake","1k") === ByteUnit.KiB.toKiB(1)) + assert(conf.getSizeAsMb("fake","1k") === ByteUnit.KiB.toMiB(1)) + assert(conf.getSizeAsGb("fake","1k") === ByteUnit.KiB.toGiB(1)) + } + + test("Test timeString conversion") { + val conf = new SparkConf() + // Simply exercise the API, we don't need a complete conversion test since that's handled in + // UtilsSuite.scala + assert(conf.getTimeAsMs("fake","1ms") === TimeUnit.MILLISECONDS.toMillis(1)) + assert(conf.getTimeAsSeconds("fake","1000ms") === TimeUnit.MILLISECONDS.toSeconds(1000)) + } + test("loading from system properties") { System.setProperty("spark.test.testProperty", "2") val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4561e5b8e9663..c4e6f06146b0a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -231,7 +231,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.rest.StandaloneRestClient") + mainClass should be ("org.apache.spark.deploy.rest.RestSubmissionClient") } else { childArgsStr should startWith ("--supervise --memory 4g --cores 5") childArgsStr should include regex "launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2" diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 8bcca926097a1..1b2b699cb11e6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.deploy import java.io.{PrintStream, OutputStream, File} +import org.apache.ivy.core.settings.IvySettings + import 
scala.collection.mutable.ArrayBuffer import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -56,24 +58,23 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll { } test("create repo resolvers") { - val resolver1 = SparkSubmitUtils.createRepoResolvers(None) + val settings = new IvySettings + val res1 = SparkSubmitUtils.createRepoResolvers(None, settings) // should have central and spark-packages by default - assert(resolver1.getResolvers.size() === 2) - assert(resolver1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "central") - assert(resolver1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "spark-packages") + assert(res1.getResolvers.size() === 4) + assert(res1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "local-m2-cache") + assert(res1.getResolvers.get(1).asInstanceOf[IBiblioResolver].getName === "local-ivy-cache") + assert(res1.getResolvers.get(2).asInstanceOf[IBiblioResolver].getName === "central") + assert(res1.getResolvers.get(3).asInstanceOf[IBiblioResolver].getName === "spark-packages") val repos = "a/1,b/2,c/3" - val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos)) - assert(resolver2.getResolvers.size() === 5) + val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos), settings) + assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: IBiblioResolver, i) => - if (i == 0) { - assert(resolver.getName === "central") - } else if (i == 1) { - assert(resolver.getName === "spark-packages") - } else { - assert(resolver.getName === s"repo-${i - 1}") - assert(resolver.getRoot === expected(i - 2)) + if (i > 3) { + assert(resolver.getName === s"repo-${i - 3}") + assert(resolver.getRoot === expected(i - 4)) } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index fcae603c7d18e..9e367a0d9af0d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -224,9 +224,9 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers EventLoggingListener.initEventLog(new FileOutputStream(file)) } val writer = new OutputStreamWriter(bstream, "UTF-8") - try { + Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) - } finally { + } { writer.close() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 8e09976636386..0a318a27ac212 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -39,9 +39,9 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. 
*/ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { - private val client = new StandaloneRestClient + private val client = new RestSubmissionClient private var actorSystem: Option[ActorSystem] = None - private var server: Option[StandaloneRestServer] = None + private var server: Option[RestSubmissionServer] = None override def afterEach() { actorSystem.foreach(_.shutdown()) @@ -89,7 +89,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { conf.set("spark.app.name", "dreamer") val appArgs = Array("one", "two", "six") // main method calls this - val response = StandaloneRestClient.run("app-resource", "main-class", appArgs, conf) + val response = RestSubmissionClient.run("app-resource", "main-class", appArgs, conf) val submitResponse = getSubmitResponse(response) assert(submitResponse.action === Utils.getFormattedClassName(submitResponse)) assert(submitResponse.serverSparkVersion === SPARK_VERSION) @@ -208,7 +208,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val json = constructSubmitRequest(masterUrl).toJson val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" @@ -238,7 +238,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("good request paths, bad requests") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill" val statusRequestPath = s"$httpUrl/$v/submissions/status" @@ -276,7 +276,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("bad request paths") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val (response1, code1) = sendHttpRequestWithResponse(httpUrl, "GET") val (response2, code2) = sendHttpRequestWithResponse(s"$httpUrl/", "GET") val (response3, code3) = sendHttpRequestWithResponse(s"$httpUrl/$v", "GET") @@ -292,7 +292,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(code5 === HttpServletResponse.SC_BAD_REQUEST) assert(code6 === HttpServletResponse.SC_BAD_REQUEST) assert(code7 === HttpServletResponse.SC_BAD_REQUEST) - assert(code8 === StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION) + assert(code8 === RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) // all responses should be error responses val errorResponse1 = getErrorResponse(response1) val errorResponse2 = getErrorResponse(response2) @@ -310,13 +310,13 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { assert(errorResponse5.highestProtocolVersion === null) assert(errorResponse6.highestProtocolVersion === null) assert(errorResponse7.highestProtocolVersion === null) - assert(errorResponse8.highestProtocolVersion === StandaloneRestServer.PROTOCOL_VERSION) + assert(errorResponse8.highestProtocolVersion === RestSubmissionServer.PROTOCOL_VERSION) } test("server returns unknown fields") { val masterUrl = startSmartServer() val httpUrl = masterUrl.replace("spark://", 
"http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val oldJson = constructSubmitRequest(masterUrl).toJson val oldFields = parse(oldJson).asInstanceOf[JObject].obj @@ -340,7 +340,7 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { test("client handles faulty server") { val masterUrl = startFaultyServer() val httpUrl = masterUrl.replace("spark://", "http://") - val v = StandaloneRestServer.PROTOCOL_VERSION + val v = RestSubmissionServer.PROTOCOL_VERSION val submitRequestPath = s"$httpUrl/$v/submissions/create" val killRequestPath = s"$httpUrl/$v/submissions/kill/anything" val statusRequestPath = s"$httpUrl/$v/submissions/status/anything" @@ -400,9 +400,9 @@ class StandaloneRestSubmitSuite extends FunSuite with BeforeAndAfterEach { val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) val _server = if (faulty) { - new FaultyStandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } else { - new StandaloneRestServer(localhost, 0, fakeMasterRef, "spark://fake:7077", conf) + new StandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") } val port = _server.start() // set these to clean them up after every test @@ -563,20 +563,18 @@ private class SmarterMaster extends Actor { private class FaultyStandaloneRestServer( host: String, requestedPort: Int, + masterConf: SparkConf, masterActor: ActorRef, - masterUrl: String, - masterConf: SparkConf) - extends StandaloneRestServer(host, requestedPort, masterActor, masterUrl, masterConf) { + masterUrl: String) + extends RestSubmissionServer(host, requestedPort, masterConf) { - protected override val contextToServlet = Map[String, StandaloneRestServlet]( - s"$baseContext/create/*" -> new MalformedSubmitServlet, - s"$baseContext/kill/*" -> new InvalidKillServlet, - s"$baseContext/status/*" -> new ExplodingStatusServlet, - "/*" -> new ErrorServlet - ) + protected override val submitRequestServlet = new MalformedSubmitServlet + protected override val killRequestServlet = new InvalidKillServlet + protected override val statusRequestServlet = new ExplodingStatusServlet /** A faulty servlet that produces malformed responses. */ - class MalformedSubmitServlet extends SubmitRequestServlet(masterActor, masterUrl, masterConf) { + class MalformedSubmitServlet + extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -586,7 +584,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends KillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -595,7 +593,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. 
*/ - class ExplodingStatusServlet extends StatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 190b08d950a02..ef3e213f1fcce 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -21,7 +21,7 @@ import java.io.{File, FileWriter, PrintWriter} import scala.collection.mutable.ArrayBuffer -import org.apache.commons.lang.math.RandomUtils +import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} @@ -60,7 +60,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext tmpFile = new File(testTempDir, getClass.getSimpleName + ".txt") val pw = new PrintWriter(new FileWriter(tmpFile)) for (x <- 1 to numRecords) { - pw.println(RandomUtils.nextInt(numBuckets)) + pw.println(RandomUtils.nextInt(0, numBuckets)) } pw.close() diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 94bfa67451892..46d2e5173acae 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -17,14 +17,16 @@ package org.apache.spark.network.netty +import java.io.InputStreamReader import java.nio._ +import java.nio.charset.Charset import java.util.concurrent.TimeUnit import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} import scala.util.{Failure, Success, Try} -import org.apache.commons.io.IOUtils +import com.google.common.io.CharStreams import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.network.{BlockDataManager, BlockTransferService} @@ -32,7 +34,7 @@ import org.apache.spark.storage.{BlockId, ShuffleBlockId} import org.apache.spark.{SecurityManager, SparkConf} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMatchers} +import org.scalatest.{FunSuite, ShouldMatchers} class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { test("security default off") { @@ -113,7 +115,9 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh val result = fetchBlock(exec0, exec1, "1", blockId) match { case Success(buf) => - IOUtils.toString(buf.createInputStream()) should equal(blockString) + val actualString = CharStreams.toString( + new InputStreamReader(buf.createInputStream(), Charset.forName("UTF-8"))) + actualString should equal(blockString) buf.release() Success() case Failure(t) => diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index aea76c1adcc09..85eb2a1d07ba4 100644 --- 
a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0, 0, null) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index df42faab64505..ef8c36a28655b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -99,6 +99,27 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4)) } + test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + val unionRdd = sc.union(rddWithNoPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[UnionRDD[_]]) + } + + test("SparkContext.union creates PartitionAwareUnionRDD if all RDDs have partitioners") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val unionRdd = sc.union(rddWithPartitioner, rddWithPartitioner) + assert(unionRdd.isInstanceOf[PartitionerAwareUnionRDD[_]]) + } + + test("PartitionAwareUnionRDD raises exception if at least one RDD has no partitioner") { + val rddWithPartitioner = sc.parallelize(Seq(1->true)).partitionBy(new HashPartitioner(1)) + val rddWithNoPartitioner = sc.parallelize(Seq(2->true)) + intercept[IllegalArgumentException] { + new PartitionerAwareUnionRDD(sc, Seq(rddWithNoPartitioner, rddWithPartitioner)) + } + } + test("partitioner aware union") { def makeRDDWithPartitioner(seq: Seq[Int]): RDD[Int] = { sc.makeRDD(seq, 1) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 44c88b00c442a..ae3339d80f9c6 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -100,8 +100,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") - val reply = newRpcEndpointRef.askWithReply[String]("Echo") + val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithRetry[String]("Echo") assert("Echo" === reply) } @@ -115,7 +115,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } }) - val reply = rpcEndpointRef.askWithReply[String]("hello") + val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } @@ -134,7 +134,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { - val reply = rpcEndpointRef.askWithReply[String]("hello") + val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { 
anotherEnv.shutdown() @@ -162,7 +162,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { val e = intercept[Exception] { - rpcEndpointRef.askWithReply[String]("hello", 1 millis) + rpcEndpointRef.askWithRetry[String]("hello", 1 millis) } assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) } finally { @@ -399,7 +399,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val f = endpointRef.sendWithReply[String]("Hi") + val f = endpointRef.ask[String]("Hi") val ack = Await.result(f, 5 seconds) assert("ack" === ack) @@ -419,7 +419,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") val ack = Await.result(f, 5 seconds) assert("ack" === ack) } finally { @@ -437,7 +437,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val f = endpointRef.sendWithReply[String]("Hi") + val f = endpointRef.ask[String]("Hi") val e = intercept[SparkException] { Await.result(f, 5 seconds) } @@ -460,7 +460,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { Await.result(f, 5 seconds) } @@ -529,7 +529,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") try { - val f = rpcEndpointRef.sendWithReply[String]("hello") + val f = rpcEndpointRef.ask[String]("hello") intercept[TimeoutException] { Await.result(f, 1 seconds) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 057e226916027..83ae8701243e5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala new file mode 100644 index 0000000000000..f28e29e9b8d8e --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.mesos + +import java.util.Date + +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.deploy.Command +import org.apache.spark.deploy.mesos.MesosDriverDescription +import org.apache.spark.scheduler.cluster.mesos._ +import org.apache.spark.{LocalSparkContext, SparkConf} + + +class MesosClusterSchedulerSuite extends FunSuite with LocalSparkContext with MockitoSugar { + + private val command = new Command("mainClass", Seq("arg"), null, null, null, null) + + test("can queue drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val response2 = + scheduler.submitDriver(new MesosDriverDescription( + "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + assert(response2.success) + val state = scheduler.getSchedulerState() + val queuedDrivers = state.queuedDrivers.toList + assert(queuedDrivers(0).submissionId == response.submissionId) + assert(queuedDrivers(1).submissionId == response2.submissionId) + } + + test("can kill queued drivers") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + scheduler.start() + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", 1000, 1, true, + command, Map[String, String](), "s1", new Date())) + assert(response.success) + val killResponse = scheduler.killDriver(response.submissionId) + assert(killResponse.success) + val state = scheduler.getSchedulerState() + assert(state.queuedDrivers.isEmpty) + } +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index 967c9e9899c9d..da98d09184735 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -33,8 +33,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo without resizable output buffer should fail on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") + 
conf.set("spark.kryoserializer.buffer.max", "1m") val sc = new SparkContext("local", "test", conf) intercept[SparkException](sc.parallelize(x).collect()) LocalSparkContext.stop(sc) @@ -43,8 +43,8 @@ class KryoSerializerResizableOutputSuite extends FunSuite { test("kryo with resizable output buffer should succeed on large array") { val conf = new SparkConf(false) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryoserializer.buffer.mb", "1") - conf.set("spark.kryoserializer.buffer.max.mb", "2") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.kryoserializer.buffer.max", "2m") val sc = new SparkContext("local", "test", conf) assert(sc.parallelize(x).collect() === x) LocalSparkContext.stop(sc) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index b070a54aa989b..1b13559e77cb8 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -269,7 +269,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("serialization buffer overflow reporting") { import org.apache.spark.SparkException - val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max.mb" + val kryoBufferMaxProperty = "spark.kryoserializer.buffer.max" val largeObject = (1 to 1000000).toArray diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ffa5162a31841..f647200402ecb 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -50,7 +50,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 7d82a7c66ad1a..f5b410f41da27 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -55,7 +55,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach val shuffleManager = new HashShuffleManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer.mb", "1") + conf.set("spark.kryoserializer.buffer", "1m") val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. 
@@ -356,7 +356,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - val reregister = !master.driverEndpoint.askWithReply[Boolean]( + val reregister = !master.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -814,14 +814,14 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach // be nice to refactor classes involved in disk storage in a way that // allows for easier testing. val blockManager = mock(classOf[BlockManager]) - when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "0")) val diskBlockManager = new DiskBlockManager(blockManager, conf) val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val mapped = diskStoreMapped.getBytes(blockId).get - when(blockManager.conf).thenReturn(conf.clone.set(confKey, (1000 * 1000).toString)) + when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) val notMapped = diskStoreNotMapped.getBytes(blockId).get diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 37b593b2c5f79..2080c432d77db 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0), + new TaskContextImpl(0, 0, 0, 0, null), transfer, blockManager, blocksByAddress, @@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 67a9f75ff2187..28915bd53354e 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester} class DummyClass1 {} @@ -96,6 +98,22 @@ class SizeEstimatorSuite // Past size 100, our samples 100 elements, but we should still get 
the right size. assertResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3))) + + val arr = new Array[Char](100000) + assertResult(200016)(SizeEstimator.estimate(arr)) + assertResult(480032)(SizeEstimator.estimate(Array.fill(10000)(new DummyString(arr)))) + + val buf = new ArrayBuffer[DummyString]() + for (i <- 0 until 5000) { + buf.append(new DummyString(new Array[Char](10))) + } + assertResult(340016)(SizeEstimator.estimate(buf.toArray)) + + for (i <- 0 until 5000) { + buf.append(new DummyString(arr)) + } + assertResult(683912)(SizeEstimator.estimate(buf.toArray)) + // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 // 10 pointers plus 8-byte object diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 1ba99803f5a0e..62a3cbcdf69ea 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -23,7 +23,6 @@ import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols import java.util.concurrent.TimeUnit import java.util.Locale -import java.util.PriorityQueue import scala.collection.mutable.ListBuffer import scala.util.Random @@ -35,6 +34,7 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.util.ByteUnit import org.apache.spark.SparkConf class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -65,6 +65,10 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { assert(Utils.timeStringAsMs("1d") === TimeUnit.DAYS.toMillis(1)) // Test invalid strings + intercept[NumberFormatException] { + Utils.timeStringAsMs("600l") + } + intercept[NumberFormatException] { Utils.timeStringAsMs("This breaks 600s") } @@ -82,6 +86,100 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { } } + test("Test byteString conversion") { + // Test zero + assert(Utils.byteStringAsBytes("0") === 0) + + assert(Utils.byteStringAsGb("1") === 1) + assert(Utils.byteStringAsGb("1g") === 1) + assert(Utils.byteStringAsGb("1023m") === 0) + assert(Utils.byteStringAsGb("1024m") === 1) + assert(Utils.byteStringAsGb("1048575k") === 0) + assert(Utils.byteStringAsGb("1048576k") === 1) + assert(Utils.byteStringAsGb("1k") === 0) + assert(Utils.byteStringAsGb("1t") === ByteUnit.TiB.toGiB(1)) + assert(Utils.byteStringAsGb("1p") === ByteUnit.PiB.toGiB(1)) + + assert(Utils.byteStringAsMb("1") === 1) + assert(Utils.byteStringAsMb("1m") === 1) + assert(Utils.byteStringAsMb("1048575b") === 0) + assert(Utils.byteStringAsMb("1048576b") === 1) + assert(Utils.byteStringAsMb("1023k") === 0) + assert(Utils.byteStringAsMb("1024k") === 1) + assert(Utils.byteStringAsMb("3645k") === 3) + assert(Utils.byteStringAsMb("1024gb") === 1048576) + assert(Utils.byteStringAsMb("1g") === ByteUnit.GiB.toMiB(1)) + assert(Utils.byteStringAsMb("1t") === ByteUnit.TiB.toMiB(1)) + assert(Utils.byteStringAsMb("1p") === ByteUnit.PiB.toMiB(1)) + + assert(Utils.byteStringAsKb("1") === 1) + assert(Utils.byteStringAsKb("1k") === 1) + assert(Utils.byteStringAsKb("1m") === ByteUnit.MiB.toKiB(1)) + assert(Utils.byteStringAsKb("1g") === ByteUnit.GiB.toKiB(1)) + assert(Utils.byteStringAsKb("1t") === ByteUnit.TiB.toKiB(1)) + assert(Utils.byteStringAsKb("1p") === ByteUnit.PiB.toKiB(1)) + + assert(Utils.byteStringAsBytes("1") === 1) + assert(Utils.byteStringAsBytes("1k") === ByteUnit.KiB.toBytes(1)) + 
assert(Utils.byteStringAsBytes("1m") === ByteUnit.MiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1g") === ByteUnit.GiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1t") === ByteUnit.TiB.toBytes(1)) + assert(Utils.byteStringAsBytes("1p") === ByteUnit.PiB.toBytes(1)) + + // Overflow handling, 1073741824p exceeds Long.MAX_VALUE if converted straight to Bytes + // This demonstrates that we can have e.g 1024^3 PB without overflowing. + assert(Utils.byteStringAsGb("1073741824p") === ByteUnit.PiB.toGiB(1073741824)) + assert(Utils.byteStringAsMb("1073741824p") === ByteUnit.PiB.toMiB(1073741824)) + + // Run this to confirm it doesn't throw an exception + assert(Utils.byteStringAsBytes("9223372036854775807") === 9223372036854775807L) + assert(ByteUnit.PiB.toPiB(9223372036854775807L) === 9223372036854775807L) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to bytes + Utils.byteStringAsBytes("9223372036854775808") + } + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MAX when converted to TB + ByteUnit.PiB.toTiB(9223372036854775807L) + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064") + } + + // Test fractional string + intercept[NumberFormatException] { + Utils.byteStringAsMb("0.064m") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("500ub") + } + + // Test invalid strings + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600b") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This breaks 600") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("600gb This breaks") + } + + intercept[NumberFormatException] { + Utils.byteStringAsBytes("This 123mb breaks") + } + } + test("bytesToString") { assert(Utils.bytesToString(10) === "10.0 B") assert(Utils.bytesToString(1500) === "1500.0 B") diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b5a67dd783b93..3dbb35f7054a2 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -119,7 +119,7 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO build/mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-version-to-2.11.sh diff --git a/docs/configuration.md b/docs/configuration.md index d587b91124cb8..72105feba4919 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -48,6 +48,17 @@ The following format is accepted: 5d (days) 1y (years) + +Properties that specify a byte size should be configured with a unit of size. +The following format is accepted: + + 1b (bytes) + 1k or 1kb (kibibytes = 1024 bytes) + 1m or 1mb (mebibytes = 1024 kibibytes) + 1g or 1gb (gibibytes = 1024 mebibytes) + 1t or 1tb (tebibytes = 1024 gibibytes) + 1p or 1pb (pebibytes = 1024 tebibytes) + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. 
For instance, if you'd like to run the same application with different masters or different @@ -272,12 +283,11 @@ Apart from these, the following properties are also available, and may be useful - spark.executor.logs.rolling.size.maxBytes + spark.executor.logs.rolling.maxSize (none) Set the max size of the file by which the executor logs will be rolled over. - Rolling is disabled by default. Value is set in terms of bytes. - See spark.executor.logs.rolling.maxRetainedFiles + Rolling is disabled by default. See spark.executor.logs.rolling.maxRetainedFiles for automatic cleaning of old logs. @@ -366,10 +376,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -403,10 +413,10 @@ Apart from these, the following properties are also available, and may be useful - - + + @@ -582,18 +592,18 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -641,19 +651,19 @@ Apart from these, the following properties are also available, and may be useful - - + + - - + + @@ -698,9 +708,9 @@ Apart from these, the following properties are also available, and may be useful - + @@ -816,9 +826,9 @@ Apart from these, the following properties are also available, and may be useful - + diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 963e88a3e1d8f..8d9c2ba2041b2 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -32,7 +32,7 @@ Resource allocation can be configured as follows, based on the cluster type: * **Standalone mode:** By default, applications submitted to the standalone mode cluster will run in FIFO (first-in-first-out) order, and each application will try to use all available nodes. You can limit the number of nodes an application uses by setting the `spark.cores.max` configuration property in it, - or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. + or change the default for applications that don't set this setting through `spark.deploy.defaultCores`. Finally, in addition to controlling cores, each application's `spark.executor.memory` setting controls its memory use. * **Mesos:** To use static partitioning on Mesos, set the `spark.mesos.coarse` configuration property to `true`, diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 594bf78b67713..8f53d8201a089 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -78,6 +78,9 @@ To verify that the Mesos cluster is ready for Spark, navigate to the Mesos maste To use Mesos from Spark, you need a Spark binary package available in a place accessible by Mesos, and a Spark driver program configured to connect to Mesos. +Alternatively, you can also install Spark in the same location in all the Mesos slaves, and configure +`spark.mesos.executor.home` (defaults to SPARK_HOME) to point to that location. + ## Uploading Spark Package When Mesos runs a task on a Mesos slave for the first time, that slave must have a Spark binary @@ -107,7 +110,11 @@ the `make-distribution.sh` script included in a Spark source tarball/checkout. The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. -The driver also needs some configuration in `spark-env.sh` to interact properly with Mesos: +## Client Mode + +In client mode, a Spark Mesos framework is launched directly on the client machine and waits for the driver output. 
+ +The driver needs some configuration in `spark-env.sh` to interact properly with Mesos: 1. In `spark-env.sh` set some environment variables: * `export MESOS_NATIVE_JAVA_LIBRARY=`. This path is typically @@ -129,8 +136,7 @@ val sc = new SparkContext(conf) {% endhighlight %} (You can also use [`spark-submit`](submitting-applications.html) and configure `spark.executor.uri` -in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file. Note -that `spark-submit` currently only supports deploying the Spark driver in `client` mode for Mesos.) +in the [conf/spark-defaults.conf](configuration.html#loading-default-configurations) file.) When running a shell, the `spark.executor.uri` parameter is inherited from `SPARK_EXECUTOR_URI`, so it does not need to be redundantly passed in as a system property. @@ -139,6 +145,17 @@ it does not need to be redundantly passed in as a system property. ./bin/spark-shell --master mesos://host:5050 {% endhighlight %} +## Cluster mode + +Spark on Mesos also supports cluster mode, where the driver is launched in the cluster and the client +can find the results of the driver from the Mesos Web UI. + +To use cluster mode, you must start the MesosClusterDispatcher in your cluster via the `sbin/start-mesos-dispatcher.sh` script, +passing in the Mesos master url (e.g: mesos://host:5050). + +From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master url +to the url of the MesosClusterDispatcher (e.g: mesos://dispatcher:7077). You can view driver statuses on the +Spark cluster Web UI. # Mesos Run Modes diff --git a/docs/tuning.md b/docs/tuning.md index cbd227868b248..1cb223e74f382 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -60,7 +60,7 @@ val sc = new SparkContext(conf) The [Kryo documentation](https://github.com/EsotericSoftware/kryo) describes more advanced registration options, such as adding custom serialization code. -If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb` +If your objects are large, you may also need to increase the `spark.kryoserializer.buffer` config property. The default is 2, but this value needs to be large enough to hold the *largest* object you will serialize. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java index 36207ae38d9a9..fd53c81cc4974 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java @@ -58,7 +58,7 @@ public Tuple2 call(Tuple2 doc_id) { corpus.cache(); // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = new LDA().setK(3).run(corpus); + DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); // Output topics. Each is a distribution over words (matching word count vectors) System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index e17819d5feb76..5b82a14fba413 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -54,8 +54,9 @@ Run with example jar: ./bin/spark-submit --driver-class-path /path/to/example/jar \ - /path/to/examples/hbase_inputformat.py
  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>

- <td><code>spark.reducer.maxMbInFlight</code></td> <td>48</td>
+ <td><code>spark.reducer.maxSizeInFlight</code></td> <td>48m</td>
-   Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since
+   Maximum size of map outputs to fetch simultaneously from each reduce task. Since
    each output requires us to create a buffer to receive it, this represents a fixed memory
    overhead per reduce task, so keep it small unless you have a large amount of memory.

- <td><code>spark.shuffle.file.buffer.kb</code></td> <td>32</td>
+ <td><code>spark.shuffle.file.buffer</code></td> <td>32k</td>
-   Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers
+   Size of the in-memory buffer for each shuffle file output stream. These buffers
    reduce the number of disk seeks and system calls made in creating intermediate shuffle files.

- <td><code>spark.io.compression.lz4.block.size</code></td> <td>32768</td>
+ <td><code>spark.io.compression.lz4.blockSize</code></td> <td>32k</td>
-   Block size (in bytes) used in LZ4 compression, in the case when LZ4 compression codec
+   Block size used in LZ4 compression, in the case when LZ4 compression codec
    is used. Lowering this block size will also lower shuffle memory usage when LZ4 is used.

- <td><code>spark.io.compression.snappy.block.size</code></td> <td>32768</td>
+ <td><code>spark.io.compression.snappy.blockSize</code></td> <td>32k</td>
-   Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec
+   Block size used in Snappy compression, in the case when Snappy compression codec
    is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.

- <td><code>spark.kryoserializer.buffer.max.mb</code></td> <td>64</td>
+ <td><code>spark.kryoserializer.buffer.max</code></td> <td>64m</td>
-   Maximum allowable size of Kryo serialization buffer, in megabytes. This must be larger than any
+   Maximum allowable size of Kryo serialization buffer. This must be larger than any
    object you attempt to serialize. Increase this if you get a "buffer limit exceeded" exception
    inside Kryo.

- <td><code>spark.kryoserializer.buffer.mb</code></td> <td>0.064</td>
+ <td><code>spark.kryoserializer.buffer</code></td> <td>64k</td>
-   Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer
+   Initial size of Kryo's serialization buffer. Note that there will be one buffer
    per core on each worker. This buffer will grow up to spark.kryoserializer.buffer.max.mb
    if needed.

  <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>

  <td><code>spark.broadcast.blockSize</code></td>
- <td>4096</td>
+ <td>4m</td>
-   Size of each piece of a block in kilobytes for TorrentBroadcastFactory.
+   Size of each piece of a block for TorrentBroadcastFactory.
    Too large a value decreases parallelism during broadcast (makes it slower); however, if it is
    too small, BlockManager might take a performance hit.

  <td><code>spark.storage.memoryMapThreshold</code></td>
- <td>2097152</td>
+ <td>2m</td>
-   Size of a block, in bytes, above which Spark memory maps when reading a block from disk.
+   Size of a block above which Spark memory maps when reading a block from disk.
    This prevents Spark from memory mapping very small blocks. In general, memory mapping has high
    overhead for blocks close to or below the page size of the operating system.
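As a sanity check on the rows above, the new unit-suffixed defaults denote the same quantities as the old raw numbers; a small illustrative sketch using the byteString helpers from the UtilsSuite additions (Utils is a private[spark] utility, so this mirrors those tests rather than user-facing API):

    import org.apache.spark.util.Utils

    // Illustrative equivalences for the renamed defaults in the table above.
    Utils.byteStringAsBytes("2m")   // 2097152 bytes, the old spark.storage.memoryMapThreshold default
    Utils.byteStringAsKb("4m")      // 4096, matching the old spark.broadcast.blockSize default in KiB
    Utils.byteStringAsBytes("32k")  // 32768 bytes, the old lz4/snappy block size default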
+ /path/to/examples/hbase_inputformat.py <host> <table> [<znode>]
  Assumes you have some data in HBase already, running on <host>, in <table>
+ optionally, you can specify parent znode for your hbase cluster - """, file=sys.stderr) exit(-1) @@ -64,6 +65,9 @@ sc = SparkContext(appName="HBaseInputFormat") conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + if len(sys.argv) > 3: + conf = {"hbase.zookeeper.quorum": host, "zookeeper.znode.parent": sys.argv[3], + "hbase.mapreduce.inputtable": table} keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter" valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter" diff --git a/examples/src/main/python/streaming/direct_kafka_wordcount.py b/examples/src/main/python/streaming/direct_kafka_wordcount.py new file mode 100644 index 0000000000000..6ef188a220c51 --- /dev/null +++ b/examples/src/main/python/streaming/direct_kafka_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text directly received from Kafka in every 2 seconds. + Usage: direct_kafka_wordcount.py + + To run this on your local machine, you need to setup Kafka and create a producer first, see + http://kafka.apache.org/documentation.html#quickstart + + and then run the example + `$ bin/spark-submit --jars external/kafka-assembly/target/scala-*/\ + spark-streaming-kafka-assembly-*.jar \ + examples/src/main/python/streaming/direct_kafka_wordcount.py \ + localhost:9092 test` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.kafka import KafkaUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: direct_kafka_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingDirectKafkaWordCount") + ssc = StreamingContext(sc, 2) + + brokers, topic = sys.argv[1:] + kvs = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 08a93595a2e17..a1850390c0a86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -26,7 +26,7 @@ import scopt.OptionParser import org.apache.log4j.{Level, Logger} import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA} import org.apache.spark.mllib.linalg.{Vector, Vectors} 
import org.apache.spark.rdd.RDD @@ -137,7 +137,7 @@ object LDAExample { sc.setCheckpointDir(params.checkpointDir.get) } val startTime = System.nanoTime() - val ldaModel = lda.run(corpus) + val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] val elapsed = (System.nanoTime() - startTime) / 1e9 println(s"Finished training LDA model. Summary:") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 0bc36ea65e1ab..99588b0984ab2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -100,7 +100,7 @@ object MovieLensALS { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) - .set("spark.kryoserializer.buffer.mb", "8") + .set("spark.kryoserializer.buffer", "8m") } val sc = new SparkContext(conf) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 67907bbfb6d1b..1f3e619d97a24 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -35,6 +35,10 @@ http://spark.apache.org/ + + org.apache.commons + commons-lang3 + org.apache.flume flume-ng-sdk diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 4373be443e67d..fd01807fc3ac4 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -21,9 +21,9 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.flume.Channel -import org.apache.commons.lang.RandomStringUtils import com.google.common.util.concurrent.ThreadFactoryBuilder +import org.apache.flume.Channel +import org.apache.commons.lang3.RandomStringUtils /** * Class that implements the SparkFlumeProtocol, that is used by the Avro Netty Server to process diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 5a9bd4214cf51..0721ddaf7055a 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -21,6 +21,7 @@ import java.lang.{Integer => JInt} import java.lang.{Long => JLong} import java.util.{Map => JMap} import java.util.{Set => JSet} +import java.util.{List => JList} import scala.reflect.ClassTag import scala.collection.JavaConversions._ @@ -234,7 +235,6 @@ object KafkaUtils { new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler) } - /** * Create a RDD from Kafka using offset ranges for each topic and partition. 
* @@ -558,4 +558,94 @@ private class KafkaUtilsPythonHelper { topics, storageLevel) } + + def createRDD( + jsc: JavaSparkContext, + kafkaParams: JMap[String, String], + offsetRanges: JList[OffsetRange], + leaders: JMap[TopicAndPartition, Broker]): JavaPairRDD[Array[Byte], Array[Byte]] = { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jrdd = KafkaUtils.createRDD[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jsc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + offsetRanges.toArray(new Array[OffsetRange](offsetRanges.size())), + leaders, + messageHandler + ) + new JavaPairRDD(jrdd.rdd) + } + + def createDirectStream( + jssc: JavaStreamingContext, + kafkaParams: JMap[String, String], + topics: JSet[String], + fromOffsets: JMap[TopicAndPartition, JLong] + ): JavaPairInputDStream[Array[Byte], Array[Byte]] = { + + if (!fromOffsets.isEmpty) { + import scala.collection.JavaConversions._ + val topicsFromOffsets = fromOffsets.keySet().map(_.topic) + if (topicsFromOffsets != topics.toSet) { + throw new IllegalStateException(s"The specified topics: ${topics.toSet.mkString(" ")} " + + s"do not equal to the topic from offsets: ${topicsFromOffsets.mkString(" ")}") + } + } + + if (fromOffsets.isEmpty) { + KafkaUtils.createDirectStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + kafkaParams, + topics) + } else { + val messageHandler = new JFunction[MessageAndMetadata[Array[Byte], Array[Byte]], + (Array[Byte], Array[Byte])] { + def call(t1: MessageAndMetadata[Array[Byte], Array[Byte]]): (Array[Byte], Array[Byte]) = + (t1.key(), t1.message()) + } + + val jstream = KafkaUtils.createDirectStream[ + Array[Byte], + Array[Byte], + DefaultDecoder, + DefaultDecoder, + (Array[Byte], Array[Byte])]( + jssc, + classOf[Array[Byte]], + classOf[Array[Byte]], + classOf[DefaultDecoder], + classOf[DefaultDecoder], + classOf[(Array[Byte], Array[Byte])], + kafkaParams, + fromOffsets, + messageHandler) + new JavaPairInputDStream(jstream.inputDStream) + } + } + + def createOffsetRange(topic: String, partition: JInt, fromOffset: JLong, untilOffset: JLong + ): OffsetRange = OffsetRange.create(topic, partition, fromOffset, untilOffset) + + def createTopicAndPartition(topic: String, partition: JInt): TopicAndPartition = + TopicAndPartition(topic, partition) + + def createBroker(host: String, port: JInt): Broker = Broker(host, port) } diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 8028e42ffb483..261402856ac5e 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -244,7 +244,7 @@ static String quoteForBatchScript(String arg) { boolean needsQuotes = false; for (int i = 0; i < arg.length(); i++) { int c = arg.codePointAt(i); - if (Character.isWhitespace(c) || c == '"' || c == '=') { + if (Character.isWhitespace(c) || c == '"' || c == '=' || c == ',' || c == ';') { needsQuotes = true; break; } @@ -261,15 +261,14 @@ static 
String quoteForBatchScript(String arg) { quoted.append('"'); break; - case '=': - quoted.append('^'); - break; - default: break; } quoted.appendCodePoint(cp); } + if (arg.codePointAt(arg.length() - 1) == '\\') { + quoted.append("\\"); + } quoted.append("\""); return quoted.toString(); } diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index 206acfb514d86..929b29a49ed70 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -101,12 +101,9 @@ public static void main(String[] argsArray) throws Exception { * The method quotes all arguments so that spaces are handled as expected. Quotes within arguments * are "double quoted" (which is batch for escaping a quote). This page has more details about * quoting and other batch script fun stuff: http://ss64.com/nt/syntax-esc.html - * - * The command is executed using "cmd /c" and formatted in single line, since that's the - * easiest way to consume this from a batch script (see spark-class2.cmd). */ private static String prepareWindowsCommand(List cmd, Map childEnv) { - StringBuilder cmdline = new StringBuilder("cmd /c \""); + StringBuilder cmdline = new StringBuilder(); for (Map.Entry e : childEnv.entrySet()) { cmdline.append(String.format("set %s=%s", e.getKey(), e.getValue())); cmdline.append(" && "); @@ -115,7 +112,6 @@ private static String prepareWindowsCommand(List cmd, Map buildCommand(Map env) throws IOException { } else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) { javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + } else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) { + javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); + javaOptsKeys.add("SPARK_SHUFFLE_OPTS"); + memKey = "SPARK_DAEMON_MEMORY"; } else if (className.startsWith("org.apache.spark.tools.")) { String sparkHome = getSparkHome(); File toolsDir = new File(join(File.separator, sparkHome, "tools", "target", diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java index 1ae42eed8a3af..bc513ec9b3d10 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java @@ -74,7 +74,10 @@ public void testWindowsBatchQuoting() { assertEquals("\"a b c\"", quoteForBatchScript("a b c")); assertEquals("\"a \"\"b\"\" c\"", quoteForBatchScript("a \"b\" c")); assertEquals("\"a\"\"b\"\"c\"", quoteForBatchScript("a\"b\"c")); - assertEquals("\"ab^=\"\"cd\"\"\"", quoteForBatchScript("ab=\"cd\"")); + assertEquals("\"ab=\"\"cd\"\"\"", quoteForBatchScript("ab=\"cd\"")); + assertEquals("\"a,b,c\"", quoteForBatchScript("a,b,c")); + assertEquals("\"a;b;c\"", quoteForBatchScript("a;b;c")); + assertEquals("\"a,b,c\\\\\"", quoteForBatchScript("a,b,c\\")); } @Test diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index d2ca2e6871e6b..8b4b5fd8af986 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.Identifiable import 
org.apache.spark.sql.DataFrame /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 23956c512c8a6..9db3b29e10d69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -23,10 +23,9 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.types.{NumericType, StringType, StructType} import org.apache.spark.util.collection.OpenHashMap /** @@ -37,7 +36,11 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { val map = extractParamMap(paramMap) - SchemaUtils.checkColumnType(schema, map(inputCol), StringType) + val inputColName = map(inputCol) + val inputDataType = schema(inputColName).dataType + require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], + s"The input column $inputColName must be either string type or numeric type, " + + s"but got $inputDataType.") val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -51,6 +54,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** * :: AlphaComponent :: * A label indexer that maps a string column of labels to an ML column of label indices. + * If the input column is numeric, we cast it to string and index the string values. * The indices are in [0, numLabels), ordered by label frequencies. * So the most frequent label gets index 0. */ @@ -67,7 +71,9 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { val map = extractParamMap(paramMap) - val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() + val counts = dataset.select(col(map(inputCol)).cast(StringType)) + .map(_.getString(0)) + .countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) Params.inheritValues(map, this, model) @@ -119,7 +125,8 @@ class StringIndexerModel private[ml] ( val outputColName = map(outputCol) val metadata = NominalAttribute.defaultAttr .withName(outputColName).withValues(labels).toMetadata() - dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata)) + dataset.select(col("*"), + indexer(dataset(map(inputCol)).cast(StringType)).as(outputColName, metadata)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala new file mode 100644 index 0000000000000..0163fa8bd8a5b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Params for [[Word2Vec]] and [[Word2VecModel]]. + */ +private[feature] trait Word2VecBase extends Params + with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed { + + /** + * The dimension of the code that you want to transform from words. + */ + final val vectorSize = new IntParam( + this, "vectorSize", "the dimension of codes after transforming from words") + setDefault(vectorSize -> 100) + + /** @group getParam */ + def getVectorSize: Int = getOrDefault(vectorSize) + + /** + * Number of partitions for sentences of words. + */ + final val numPartitions = new IntParam( + this, "numPartitions", "number of partitions for sentences of words") + setDefault(numPartitions -> 1) + + /** @group getParam */ + def getNumPartitions: Int = getOrDefault(numPartitions) + + /** + * The minimum number of times a token must appear to be included in the word2vec model's + * vocabulary. + */ + final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + + "appear to be included in the word2vec model's vocabulary") + setDefault(minCount -> 5) + + /** @group getParam */ + def getMinCount: Int = getOrDefault(minCount) + + setDefault(stepSize -> 0.025) + setDefault(maxIter -> 1) + setDefault(seed -> 42L) + + /** + * Validate and transform the input schema. + */ + protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true)) + SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further + * natural language processing or machine learning process. 
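Purely as a usage sketch (not part of the patch), the estimator defined just below is driven like any other spark.ml Estimator; this mirrors the Word2VecSuite added later in this change. `docDF` is a placeholder DataFrame whose "text" column holds tokenized sentences (Seq[String]).

    val word2Vec = new Word2Vec()
      .setInputCol("text")       // ArrayType(StringType) column, per validateAndTransformSchema above
      .setOutputCol("result")    // appended VectorUDT column
      .setVectorSize(3)
    val model: Word2VecModel = word2Vec.fit(docDF)
    model.transform(docDF).select("result").show()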
+ */ +@AlphaComponent +final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setVectorSize(value: Int): this.type = set(vectorSize, value) + + /** @group setParam */ + def setStepSize(value: Double): this.type = set(stepSize, value) + + /** @group setParam */ + def setNumPartitions(value: Int): this.type = set(numPartitions, value) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + def setMinCount(value: Int): this.type = set(minCount, value) + + override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v } + val wordVectors = new feature.Word2Vec() + .setLearningRate(map(stepSize)) + .setMinCount(map(minCount)) + .setNumIterations(map(maxIter)) + .setNumPartitions(map(numPartitions)) + .setSeed(map(seed)) + .setVectorSize(map(vectorSize)) + .fit(input) + val model = new Word2VecModel(this, map, wordVectors) + Params.inheritValues(map, this, model) + model + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[Word2Vec]]. + */ +@AlphaComponent +class Word2VecModel private[ml] ( + override val parent: Word2Vec, + override val fittingParamMap: ParamMap, + wordVectors: feature.Word2VecModel) + extends Model[Word2VecModel] with Word2VecBase { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a sentence column to a vector column to represent the whole sentence. The transform + * is performed by averaging all word vectors it contains. 
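A minimal, framework-free sketch of the averaging described above (illustrative only): words absent from the vocabulary contribute nothing, but the sum is still divided by the full sentence length, matching the scal(1.0 / sentence.size, cum) call in the transform below.

    def averageSentence(
        sentence: Seq[String],
        vectors: Map[String, Array[Double]],
        size: Int): Array[Double] = {
      val sum = new Array[Double](size)
      // skip out-of-vocabulary words, accumulate the rest
      for (word <- sentence; vec <- vectors.get(word); i <- 0 until size) {
        sum(i) += vec(i)
      }
      if (sentence.isEmpty) sum else sum.map(_ / sentence.size)
    }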
+ */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = extractParamMap(paramMap) + val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val word2Vec = udf { sentence: Seq[String] => + if (sentence.size == 0) { + Vectors.sparse(map(vectorSize), Array.empty[Int], Array.empty[Double]) + } else { + val cum = Vectors.zeros(map(vectorSize)) + val model = bWordVectors.value.getVectors + for (word <- sentence) { + if (model.contains(word)) { + axpy(1.0, bWordVectors.value.transform(word), cum) + } else { + // pass words which not belong to model + } + } + scal(1.0 / sentence.size, cum) + cum + } + } + dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol)))) + } + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ddc5907e7facd..014e124e440a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,7 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.Identifiable +import org.apache.spark.ml.util.Identifiable /** * :: AlphaComponent :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index e88c48741e99f..654cd72d53074 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -46,7 +46,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), - ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()"))) + ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")), + ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"), + ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"), + ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization.")) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a860b8834cff9..96d11ed76fa8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -276,4 +276,55 @@ trait HasSeed extends Params { /** @group getParam */ final def getSeed: Long = getOrDefault(seed) } + +/** + * :: DeveloperApi :: + * Trait for shared param elasticNetParam. + */ +@DeveloperApi +trait HasElasticNetParam extends Params { + + /** + * Param for the ElasticNet mixing parameter. 
+ * @group param + */ + final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter") + + /** @group getParam */ + final def getElasticNetParam: Double = getOrDefault(elasticNetParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param tol. + */ +@DeveloperApi +trait HasTol extends Params { + + /** + * Param for the convergence tolerance for iterative algorithms. + * @group param + */ + final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms") + + /** @group getParam */ + final def getTol: Double = getOrDefault(tol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param stepSize. + */ +@DeveloperApi +trait HasStepSize extends Params { + + /** + * Param for Step size to be used for each iteration of optimization.. + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.") + + /** @group getParam */ + final def getStepSize: Double = getOrDefault(stepSize) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 26ca7459c4fdf..f92c6816eb54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -17,21 +17,29 @@ package org.apache.spark.ml.regression +import scala.collection.mutable + +import breeze.linalg.{norm => brzNorm, DenseVector => BDV} +import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction} + import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Params, ParamMap} -import org.apache.spark.ml.param.shared._ -import org.apache.spark.mllib.linalg.{BLAS, Vector} -import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.StatCounter /** * Params for linear regression. */ private[regression] trait LinearRegressionParams extends RegressorParams - with HasRegParam with HasMaxIter - + with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** * :: AlphaComponent :: @@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setDefault(regParam -> 0.1, maxIter -> 100) - - /** @group setParam */ + /** + * Set the regularization parameter. + * Default is 0.0. + * @group setParam + */ def setRegParam(value: Double): this.type = set(regParam, value) + setDefault(regParam -> 0.0) + + /** + * Set the ElasticNet mixing parameter. + * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. + * For 0 < alpha < 1, the penalty is a combination of L1 and L2. + * Default is 0.0 which is an L2 penalty. 
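To make the mixing explicit, the single regParam is split between the two penalties exactly as in train() further below (where it is additionally divided by the label standard deviation because of the implicit standardization). The values here are arbitrary illustrations, not defaults.

    val regParam = 0.3
    val alpha = 0.8                           // elasticNetParam
    val l1RegParam = alpha * regParam         // handled by the OWLQN optimizer
    val l2RegParam = (1.0 - alpha) * regParam // folded into the cost as 0.5 * l2 * ||w||^2
    // alpha == 0.0  => pure L2 (ridge, solved with plain L-BFGS)
    // alpha == 1.0  => pure L1 (lasso)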
+ * @group setParam + */ + def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) + setDefault(elasticNetParam -> 0.0) - /** @group setParam */ + /** + * Set the maximal number of iterations. + * Default is 100. + * @group setParam + */ def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 100) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-6. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + setDefault(tol -> 1E-6) override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. - val oldDataset = extractLabeledPoints(dataset, paramMap) + // Extract columns from data. If dataset is persisted, do not persist instances. + val instances = extractLabeledPoints(dataset, paramMap).map { + case LabeledPoint(label: Double, features: Vector) => (label, features) + } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) { - oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + instances.persist(StorageLevel.MEMORY_AND_DISK) + } + + val (summarizer, statCounter) = instances.treeAggregate( + (new MultivariateOnlineSummarizer, new StatCounter))( { + case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter), + (label: Double, features: Vector)) => + (summarizer.add(features), statCounter.merge(label)) + }, { + case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter), + (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) => + (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2)) + }) + + val numFeatures = summarizer.mean.size + val yMean = statCounter.mean + val yStd = math.sqrt(statCounter.variance) + + val featuresMean = summarizer.mean.toArray + val featuresStd = summarizer.variance.toArray.map(math.sqrt) + + // Since we implicitly do the feature scaling when we compute the cost function + // to improve the convergence, the effective regParam will be changed. + val effectiveRegParam = paramMap(regParam) / yStd + val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam + val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam + + val costFun = new LeastSquaresCostFun(instances, yStd, yMean, + featuresStd, featuresMean, effectiveL2RegParam) + + val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) { + new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol)) + } else { + new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol)) + } + + val initialWeights = Vectors.zeros(numFeatures) + val states = + optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector) + + var state = states.next() + val lossHistory = mutable.ArrayBuilder.make[Double] + + while (states.hasNext) { + lossHistory += state.value + state = states.next() + } + lossHistory += state.value + + // TODO: Based on the sparsity of weights, we may convert the weights to the sparse vector. + // The weights are trained in the scaled space; we're converting them back to + // the original space. 
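For reference, the rescaling performed by the block that follows, together with the closed-form intercept computed afterwards, is equivalent to this small sketch (rawWeights, yStd, yMean, featuresStd and featuresMean as in the surrounding code). The intercept follows from the fact that the residuals of the centered problem have zero mean at the optimum, which is the glmnet result referenced below.

    val originalScaleWeights = rawWeights.indices.map { i =>
      if (featuresStd(i) != 0.0) rawWeights(i) * yStd / featuresStd(i) else 0.0
    }.toArray
    val intercept =
      yMean - originalScaleWeights.zip(featuresMean).map { case (w, m) => w * m }.sum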
+ val weights = { + val rawWeights = state.x.toArray.clone() + var i = 0 + while (i < rawWeights.length) { + rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 } + i += 1 + } + Vectors.dense(rawWeights) } - // Train model - val lr = new LinearRegressionWithSGD() - lr.optimizer - .setRegParam(paramMap(regParam)) - .setNumIterations(paramMap(maxIter)) - val model = lr.run(oldDataset) - val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + // The intercept in R's GLMNET is computed using closed form after the coefficients are + // converged. See the following discussion for detail. + // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet + val intercept = yMean - dot(weights, Vectors.dense(featuresMean)) if (handlePersistence) { - oldDataset.unpersist() + instances.unpersist() } - lrm + new LinearRegressionModel(this, paramMap, weights, intercept) } } @@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] ( with LinearRegressionParams { override protected def predict(features: Vector): Double = { - BLAS.dot(features, weights) + intercept + dot(features, weights) + intercept } override protected def copy(): LinearRegressionModel = { @@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] ( m } } + +/** + * LeastSquaresAggregator computes the gradient and loss for a Least-squared loss function, + * as used in linear regression for samples in sparse or dense vector in a online fashion. + * + * Two LeastSquaresAggregator can be merged together to have a summary of loss and gradient of + * the corresponding joint dataset. + * + + * * Compute gradient and loss for a Least-squared loss function, as used in linear regression. + * This is correct for the averaged least squares loss function (mean squared error) + * L = 1/2n ||A weights-y||^2 + * See also the documentation for the precise formulation. + * + * @param weights weights/coefficients corresponding to features + * + * @param updater Updater to be used to update weights after every iteration. + */ +private class LeastSquaresAggregator( + weights: Vector, + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double]) extends Serializable { + + private var totalCnt: Long = 0 + private var lossSum = 0.0 + private var diffSum = 0.0 + + private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = { + val weightsArray = weights.toArray.clone() + var sum = 0.0 + var i = 0 + while (i < weightsArray.length) { + if (featuresStd(i) != 0.0) { + weightsArray(i) /= featuresStd(i) + sum += weightsArray(i) * featuresMean(i) + } else { + weightsArray(i) = 0.0 + } + i += 1 + } + (weightsArray, -sum + labelMean / labelStd, weightsArray.length) + } + private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray) + + private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim) + + /** + * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * + * @param label The label for this data point. + * @param data The features for one data point in dense/sparse vector format to be added + * into this aggregator. + * @return This LeastSquaresAggregator object. + */ + def add(label: Double, data: Vector): this.type = { + require(dim == data.size, s"Dimensions mismatch when adding new sample." 
+ + s" Expecting $dim but got ${data.size}.") + + val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset + + if (diff != 0) { + val localGradientSumArray = gradientSumArray + data.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + localGradientSumArray(index) += diff * value / featuresStd(index) + } + } + lossSum += diff * diff / 2.0 + diffSum += diff + } + + totalCnt += 1 + this + } + + /** + * Merge another LeastSquaresAggregator, and update the loss and gradient + * of the objective function. + * (Note that it's in place merging; as a result, `this` object will be modified.) + * + * @param other The other LeastSquaresAggregator to be merged. + * @return This LeastSquaresAggregator object. + */ + def merge(other: LeastSquaresAggregator): this.type = { + require(dim == other.dim, s"Dimensions mismatch when merging with another " + + s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") + + if (other.totalCnt != 0) { + totalCnt += other.totalCnt + lossSum += other.lossSum + diffSum += other.diffSum + + var i = 0 + val localThisGradientSumArray = this.gradientSumArray + val localOtherGradientSumArray = other.gradientSumArray + while (i < dim) { + localThisGradientSumArray(i) += localOtherGradientSumArray(i) + i += 1 + } + } + this + } + + def count: Long = totalCnt + + def loss: Double = lossSum / totalCnt + + def gradient: Vector = { + val result = Vectors.dense(gradientSumArray.clone()) + + val correction = { + val temp = effectiveWeightsArray.clone() + var i = 0 + while (i < temp.length) { + temp(i) *= featuresMean(i) + i += 1 + } + Vectors.dense(temp) + } + + axpy(-diffSum, correction, result) + scal(1.0 / totalCnt, result) + result + } +} + +/** + * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost. + * It returns the loss and gradient with L2 regularization at a particular point (weights). + * It's used in Breeze's convex optimization routines. 
+ */ +private class LeastSquaresCostFun( + data: RDD[(Double, Vector)], + labelStd: Double, + labelMean: Double, + featuresStd: Array[Double], + featuresMean: Array[Double], + effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] { + + override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = { + val w = Vectors.fromBreeze(weights) + + val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd, + labelMean, featuresStd, featuresMean))( + seqOp = (c, v) => (c, v) match { + case (aggregator, (label, features)) => aggregator.add(label, features) + }, + combOp = (c1, c2) => (c1, c2) match { + case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + }) + + // regVal is the sum of weight squares for L2 regularization + val norm = brzNorm(weights, 2.0) + val regVal = 0.5 * effectiveL2regParam * norm * norm + + val loss = leastSquaresAggregator.loss + regVal + val gradient = leastSquaresAggregator.gradient + axpy(effectiveL2regParam, w, gradient) + + (loss, gradient.toBreeze.asInstanceOf[BDV[Double]]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala rename to mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala index a1d49095c24ac..8a56748ab0a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Identifiable.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ml +package org.apache.spark.ml.util import java.util.UUID diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d006b39acb213..37bf88b73b911 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,16 +17,11 @@ package org.apache.spark.mllib.clustering -import java.util.Random - -import breeze.linalg.{DenseVector => BDV, normalize} - +import breeze.linalg.{DenseVector => BDV} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.GraphImpl -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -42,16 +37,9 @@ import org.apache.spark.util.Utils * - "token": instance of a term appearing in a document * - "topic": multinomial distribution over words representing some concept * - * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented - * according to the Asuncion et al. (2009) paper referenced below. - * * References: * - Original LDA paper (journal version): * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. - * - This class implements their "smoothed" LDA model. - * - Paper which clearly explains several algorithms, including EM: - * Asuncion, Welling, Smyth, and Teh. - * "On Smoothing and Inference for Topic Models." UAI, 2009. 
* * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation * (Wikipedia)]] @@ -63,10 +51,11 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointInterval: Int) extends Logging { + private var checkpointInterval: Int, + private var ldaOptimizer: LDAOptimizer) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -220,6 +209,32 @@ class LDA private ( this } + + /** LDAOptimizer used to perform the actual calculation */ + def getOptimizer: LDAOptimizer = ldaOptimizer + + /** + * LDAOptimizer used to perform the actual calculation (default = EMLDAOptimizer) + */ + def setOptimizer(optimizer: LDAOptimizer): this.type = { + this.ldaOptimizer = optimizer + this + } + + /** + * Set the LDAOptimizer used to perform the actual calculation by algorithm name. + * Currently "em" is supported. + */ + def setOptimizer(optimizerName: String): this.type = { + this.ldaOptimizer = + optimizerName.toLowerCase match { + case "em" => new EMLDAOptimizer + case other => + throw new IllegalArgumentException(s"Only em is supported but got $other.") + } + this + } + /** * Learn an LDA model using the given dataset. * @@ -229,9 +244,9 @@ class LDA private ( * Document IDs must be unique and >= 0. * @return Inferred LDA model */ - def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { - val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointInterval) + def run(documents: RDD[(Long, Vector)]): LDAModel = { + val state = ldaOptimizer.initialState(documents, k, getDocConcentration, getTopicConcentration, + seed, checkpointInterval) var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -241,12 +256,11 @@ class LDA private ( iterationTimes(iter) = elapsedSeconds iter += 1 } - state.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(state, iterationTimes) + state.getLDAModel(iterationTimes) } /** Java-friendly version of [[run()]] */ - def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = { run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) } } @@ -320,88 +334,10 @@ private[clustering] object LDA { private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 - /** - * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. - * - * @param graph EM graph, storing current parameter estimates in vertex descriptors and - * data (token counts) in edge descriptors. 
- * @param k Number of topics - * @param vocabSize Number of unique terms - * @param docConcentration "alpha" - * @param topicConcentration "beta" or "eta" - */ - private[clustering] class EMOptimizer( - var graph: Graph[TopicCounts, TokenCount], - val k: Int, - val vocabSize: Int, - val docConcentration: Double, - val topicConcentration: Double, - checkpointInterval: Int) { - - private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointInterval) - - def next(): EMOptimizer = { - val eta = topicConcentration - val W = vocabSize - val alpha = docConcentration - - val N_k = globalTopicTotals - val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = - (edgeContext) => { - // Compute N_{wj} gamma_{wjk} - val N_wj = edgeContext.attr - // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count - // N_{wj}. - val scaledTopicDistribution: TopicCounts = - computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj - edgeContext.sendToDst((false, scaledTopicDistribution)) - edgeContext.sendToSrc((false, scaledTopicDistribution)) - } - // This is a hack to detect whether we could modify the values in-place. - // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) - val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = - (m0, m1) => { - val sum = - if (m0._1) { - m0._2 += m1._2 - } else if (m1._1) { - m1._2 += m0._2 - } else { - m0._2 + m1._2 - } - (true, sum) - } - // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. - val docTopicDistributions: VertexRDD[TopicCounts] = - graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) - .mapValues(_._2) - // Update the vertex descriptors with the new counts. - val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) - graph = newGraph - graphCheckpointer.updateGraph(newGraph) - globalTopicTotals = computeGlobalTopicTotals() - this - } - - /** - * Aggregate distributions over topics from all term vertices. - * - * Note: This executes an action on the graph RDDs. - */ - var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() - - private def computeGlobalTopicTotals(): TopicCounts = { - val numTopics = k - graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) - } - - } - /** * Compute gamma_{wjk}, a distribution over topics k. */ - private def computePTopic( + private[clustering] def computePTopic( docTopicCounts: TopicCounts, termTopicCounts: TopicCounts, totalTopicCounts: TopicCounts, @@ -427,49 +363,4 @@ private[clustering] object LDA { // normalize BDV(gamma_wj) /= sum } - - /** - * Compute bipartite term/doc graph. - */ - private def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): EMOptimizer = { - // For each document, create an edge (Document -> Term) for each unique term in the document. - val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => - // Add edges for terms with non-zero counts. - termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => - Edge(docID, term2index(term), cnt) - } - } - - val vocabSize = docs.take(1).head._2.size - - // Create vertices. - // Initially, we use random soft assignments of tokens to topics (random gamma). 
- def createVertices(): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.flatMap { edge => - val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) - val sum = gamma * edge.attr - Seq((edge.srcId, sum), (edge.dstId, sum)) - } - } - verticesTMP.reduceByKey(_ + _) - } - - val docTermVertices = createVertices() - - // Partition such that edges are grouped by document - val graph = Graph(docTermVertices, edges) - .partitionBy(PartitionStrategy.EdgePartition1D) - - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) - } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 0a3f21ecee0dc..6cf26445f20a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -203,7 +203,7 @@ class DistributedLDAModel private ( import LDA._ - private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, state.topicConcentration, iterationTimes) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala new file mode 100644 index 0000000000000..ffd72a294c6c6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can + * hold optimizer-specific parameters for users to set. + */ +@Experimental +trait LDAOptimizer{ + + /* + DEVELOPERS NOTE: + + An LDAOptimizer contains an algorithm for LDA and performs the actual computation, which + stores internal data structure (Graph or Matrix) and other parameters for the algorithm. + The interface is isolated to improve the extensibility of LDA. 
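For orientation, callers pick the optimizer through the new setters on LDA introduced by this patch; illustrative only, with `corpus` standing in for an RDD[(Long, Vector)] of (document id, term counts) pairs.

    val lda = new LDA()
      .setK(10)
      .setMaxIterations(20)
      .setOptimizer("em")                   // or .setOptimizer(new EMLDAOptimizer)
    val model: LDAModel = lda.run(corpus)   // run() now returns the base LDAModel type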
+ */ + + /** + * Initializer for the optimizer. LDA passes the common parameters to the optimizer and + * the internal structure can be initialized properly. + */ + private[clustering] def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer + + private[clustering] def next(): LDAOptimizer + + private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel +} + +/** + * :: Experimental :: + * + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + * + */ +@Experimental +class EMLDAOptimizer extends LDAOptimizer{ + + import LDA._ + + /** + * Following fields will only be initialized through initialState method + */ + private[clustering] var graph: Graph[TopicCounts, TokenCount] = null + private[clustering] var k: Int = 0 + private[clustering] var vocabSize: Int = 0 + private[clustering] var docConcentration: Double = 0 + private[clustering] var topicConcentration: Double = 0 + private[clustering] var checkpointInterval: Int = 10 + private var graphCheckpointer: PeriodicGraphCheckpointer[TopicCounts, TokenCount] = null + + /** + * Compute bipartite term/doc graph. + */ + private[clustering] override def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LDAOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). 
+ def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } + + val docTermVertices = createVertices() + + // Partition such that edges are grouped by document + this.graph = Graph(docTermVertices, edges).partitionBy(PartitionStrategy.EdgePartition1D) + this.k = k + this.vocabSize = vocabSize + this.docConcentration = docConcentration + this.topicConcentration = topicConcentration + this.checkpointInterval = checkpointInterval + this.graphCheckpointer = new + PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.globalTopicTotals = computeGlobalTopicTotals() + this + } + + private[clustering] override def next(): EMLDAOptimizer = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. 
+ */ + private[clustering] var globalTopicTotals: TopicCounts = null + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + private[clustering] override def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + require(graph != null, "graph is null, EMLDAOptimizer not initialized.") + this.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(this, iterationTimes) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 166c00cff634d..188d1e542b5b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -52,7 +52,7 @@ sealed trait Vector extends Serializable { override def equals(other: Any): Boolean = { other match { - case v2: Vector => { + case v2: Vector => if (this.size != v2.size) return false (this, v2) match { case (s1: SparseVector, s2: SparseVector) => @@ -63,20 +63,28 @@ sealed trait Vector extends Serializable { Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values) case (_, _) => util.Arrays.equals(this.toArray, v2.toArray) } - } case _ => false } } + /** + * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros + * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]]. + */ override def hashCode(): Int = { - var result: Int = size + 31 - this.foreachActive { case (index, value) => - // ignore explict 0 for comparison between sparse and dense - if (value != 0) { - result = 31 * result + index - // refer to {@link java.util.Arrays.equals} for hash algorithm - val bits = java.lang.Double.doubleToLongBits(value) - result = 31 * result + (bits ^ (bits >>> 32)).toInt + // This is a reference implementation. It calls return in foreachActive, which is slow. + // Subclasses should override it with optimized implementation. + var result: Int = 31 + size + this.foreachActive { (index, value) => + if (index < 16) { + // ignore explicit 0 for comparison between sparse and dense + if (value != 0) { + result = 31 * result + index + val bits = java.lang.Double.doubleToLongBits(value) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + return result } } result @@ -85,7 +93,7 @@ sealed trait Vector extends Serializable { /** * Converts the instance to a breeze vector. */ - private[mllib] def toBreeze: BV[Double] + private[spark] def toBreeze: BV[Double] /** * Gets the value of the ith element. @@ -108,6 +116,40 @@ sealed trait Vector extends Serializable { * with type `Double`. */ private[spark] def foreachActive(f: (Int, Double) => Unit) + + /** + * Number of active entries. An "active entry" is an element which is explicitly stored, + * regardless of its value. Note that inactive entries have value 0. + */ + def numActives: Int + + /** + * Number of nonzero elements. This scans all active values and count nonzeros. + */ + def numNonzeros: Int + + /** + * Converts this vector to a sparse vector with all explicit zeros removed. + */ + def toSparse: SparseVector + + /** + * Converts this vector to a dense vector. + */ + def toDense: DenseVector = new DenseVector(this.toArray) + + /** + * Returns a vector in either dense or sparse format, whichever uses less storage. 
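A quick usage sketch of the new members (not part of the patch): for a mostly-zero vector the size test in the body below picks the sparse form, since a sparse vector costs roughly 12 * nnz + 20 bytes versus 8 * size + 8 bytes for a dense one.

    import org.apache.spark.mllib.linalg.Vectors

    val v = Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.0)
    v.numNonzeros   // 1
    v.toSparse      // (8,[7],[7.0])
    v.compressed    // sparse, since 1.5 * (1 + 1.0) = 3.0 < 8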
+ */ + def compressed: Vector = { + val nnz = numNonzeros + // A dense vector needs 8 * size + 8 bytes, while a sparse vector needs 12 * nnz + 20 bytes. + if (1.5 * (nnz + 1.0) < size) { + toSparse + } else { + toDense + } + } } /** @@ -284,7 +326,7 @@ object Vectors { /** * Creates a vector instance from a breeze vector. */ - private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { + private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { @@ -317,7 +359,7 @@ object Vectors { case SparseVector(n, ids, vs) => vs case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } - val size = values.size + val size = values.length if (p == 1) { var sum = 0.0 @@ -371,8 +413,8 @@ object Vectors { val v1Indices = v1.indices val v2Values = v2.values val v2Indices = v2.indices - val nnzv1 = v1Indices.size - val nnzv2 = v2Indices.size + val nnzv1 = v1Indices.length + val nnzv2 = v2Indices.length var kv1 = 0 var kv2 = 0 @@ -401,7 +443,7 @@ object Vectors { case (DenseVector(vv1), DenseVector(vv2)) => var kv = 0 - val sz = vv1.size + val sz = vv1.length while (kv < sz) { val score = vv1(kv) - vv2(kv) squaredDistance += score * score @@ -422,7 +464,7 @@ object Vectors { var kv2 = 0 val indices = v1.indices var squaredDistance = 0.0 - val nnzv1 = indices.size + val nnzv1 = indices.length val nnzv2 = v2.size var iv1 = if (nnzv1 > 0) indices(kv1) else -1 @@ -451,8 +493,8 @@ object Vectors { v1Values: Array[Double], v2Indices: IndexedSeq[Int], v2Values: Array[Double]): Boolean = { - val v1Size = v1Values.size - val v2Size = v2Values.size + val v1Size = v1Values.length + val v2Size = v2Values.length var k1 = 0 var k2 = 0 var allEqual = true @@ -483,7 +525,7 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values - private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + private[spark] override def toBreeze: BV[Double] = new BDV[Double](values) override def apply(i: Int): Double = values(i) @@ -493,7 +535,7 @@ class DenseVector(val values: Array[Double]) extends Vector { private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localValues = values while (i < localValuesSize) { @@ -501,6 +543,50 @@ class DenseVector(val values: Array[Double]) extends Vector { i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + var i = 0 + val end = math.min(values.length, 16) + while (i < end) { + val v = values(i) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(values(i)) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + i += 1 + } + result + } + + override def numActives: Int = size + + override def numNonzeros: Int = { + // same as values.count(_ != 0.0) but faster + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } } object DenseVector { @@ -522,8 +608,8 @@ class SparseVector( val values: Array[Double]) extends Vector { require(indices.length == values.length, "Sparse vectors require that the dimension of the" + - s" 
indices match the dimension of the values. You provided ${indices.size} indices and " + - s" ${values.size} values.") + s" indices match the dimension of the values. You provided ${indices.length} indices and " + + s" ${values.length} values.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" @@ -543,11 +629,11 @@ class SparseVector( new SparseVector(size, indices.clone(), values.clone()) } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) private[spark] override def foreachActive(f: (Int, Double) => Unit) = { var i = 0 - val localValuesSize = values.size + val localValuesSize = values.length val localIndices = indices val localValues = values @@ -556,6 +642,59 @@ class SparseVector( i += 1 } } + + override def hashCode(): Int = { + var result: Int = 31 + size + val end = values.length + var continue = true + var k = 0 + while ((k < end) & continue) { + val i = indices(k) + if (i < 16) { + val v = values(k) + if (v != 0.0) { + result = 31 * result + i + val bits = java.lang.Double.doubleToLongBits(v) + result = 31 * result + (bits ^ (bits >>> 32)).toInt + } + } else { + continue = false + } + k += 1 + } + result + } + + override def numActives: Int = values.length + + override def numNonzeros: Int = { + var nnz = 0 + values.foreach { v => + if (v != 0.0) { + nnz += 1 + } + } + nnz + } + + override def toSparse: SparseVector = { + val nnz = numNonzeros + if (nnz == numActives) { + this + } else { + val ii = new Array[Int](nnz) + val vv = new Array[Double](nnz) + var k = 0 + foreachActive { (i, v) => + if (v != 0.0) { + ii(k) = i + vv(k) = v + k += 1 + } + } + new SparseVector(size, ii, vv) + } + } } object SparseVector { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 8bfa0d2b64995..240baeb5a158b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -37,7 +37,11 @@ abstract class Gradient extends Serializable { * * @return (gradient: Vector, loss: Double) */ - def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val gradient = Vectors.zeros(weights.size) + val loss = compute(data, label, weights, gradient) + (gradient, loss) + } /** * Compute the gradient and loss given the features of a single data point, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index ef6eccd90711a..efedc112d380e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.optimization +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV} @@ -164,7 +165,7 @@ object LBFGS extends Logging { regParam: Double, initialWeights: Vector): (Vector, Array[Double]) = { - val lossHistory = new ArrayBuffer[Double](maxNumIterations) + val lossHistory = mutable.ArrayBuilder.make[Double] val numExamples = data.count() @@ -181,17 +182,19 @@ object LBFGS extends Logging { * and regVal is the 
regularization value computed in the previous iteration as well. */ var state = states.next() - while(states.hasNext) { - lossHistory.append(state.value) + while (states.hasNext) { + lossHistory += state.value state = states.next() } - lossHistory.append(state.value) + lossHistory += state.value val weights = Vectors.fromBreeze(state.x) + val lossHistoryArray = lossHistory.result() + logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format( - lossHistory.takeRight(10).mkString(", "))) + lossHistoryArray.takeRight(10).mkString(", "))) - (weights, lossHistory.toArray) + (weights, lossHistoryArray) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala index 23b291eee070b..8a821d1b23bab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/PearsonCorrelation.scala @@ -101,7 +101,7 @@ private[stat] object PearsonCorrelation extends Correlation with Logging { Matrices.fromBreeze(cov) } - private def closeToZero(value: Double, threshhold: Double = 1e-12): Boolean = { - math.abs(value) <= threshhold + private def closeToZero(value: Double, threshold: Double = 1e-12): Boolean = { + math.abs(value) <= threshold } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0e31c7ed58df8..deac390130128 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -177,9 +177,10 @@ object GradientBoostedTrees extends Logging { treeStrategy.assertValid() // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { + val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { input.persist(StorageLevel.MEMORY_AND_DISK) - } + true + } else false timer.stop("init") @@ -265,6 +266,9 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + + if (persistedInput) input.unpersist() + if (validate) { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index c9d33787b0bb5..d7bb943e84f53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -56,6 +56,10 @@ object LinearDataGenerator { } /** + * For compatibility, the generated data without specifying the mean and variance + * will have zero mean and variance of (1.0/3.0) since the original output range is + * [-1, 1] with uniform distribution, and the variance of uniform distribution + * is (b - a)^2^ / 12 which will be (1.0/3.0) * * @param intercept Data intercept * @param weights Weights to be applied. @@ -70,10 +74,47 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, + Array.fill[Double](weights.size)(0.0), + Array.fill[Double](weights.size)(1.0 / 3.0), + nPoints, seed, eps)} + + /** + * + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. 
Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @return Seq of input. + */ + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double): Seq[LabeledPoint] = { val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0)) + Array.fill[Double](weights.length)(rnd.nextDouble)) + + x.map(vector => { + // This doesn't work if `vector` is a sparse vector. + val vectorArray = vector.toArray + var i = 0 + while (i < vectorArray.size) { + vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + }) + val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index dc10aa67c7c1f..fbe171b4b1ab1 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -88,7 +88,7 @@ public void distributedLDAModel() { .setMaxIterations(5) .setSeed(12345); - DistributedLDAModel model = lda.run(corpus); + DistributedLDAModel model = (DistributedLDAModel)lda.run(corpus); // Check: basic parameters LocalLDAModel localModel = model.toLocal(); diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 00b5d094d82f1..b6939e5870410 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -49,4 +49,23 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexer with a numeric input column") { + val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "label") + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + .fit(df) + val transformed = indexer.transform(df) + val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("100", "300", "200")) + val output = transformed.select("id", "labelIndex").map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // 100 -> 0, 200 -> 2, 300 -> 1 + val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) + assert(output === expected) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala new file mode 100644 index 0000000000000..03ba86670d453 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext} + +class Word2VecSuite extends FunSuite with MLlibTestSparkContext { + + test("Word2Vec") { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + + val sentence = "a b " * 100 + "a c " * 10 + val numOfWords = sentence.split(" ").size + val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) + + val codes = Map( + "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451), + "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342), + "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351) + ) + + val expected = doc.map { sentence => + Vectors.dense(sentence.map(codes.apply).reduce((word1, word2) => + word1.zip(word2).map { case (v1, v2) => v1 + v2 } + ).map(_ / numOfWords)) + } + + val docDF = doc.zip(expected).toDF("text", "expected") + + val model = new Word2Vec() + .setVectorSize(3) + .setInputCol("text") + .setOutputCol("result") + .fit(docDF) + + model.transform(docDF).select("result", "expected").collect().foreach { + case Row(vector1: Vector, vector2: Vector) => + assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + } + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bbb44c3e2dfc2..80323ef5201a6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -19,47 +19,149 @@ package org.apache.spark.ml.regression import org.scalatest.FunSuite -import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.mllib.linalg.DenseVector +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.{Row, SQLContext, DataFrame} class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ @transient var dataset: DataFrame = _ + /** + * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + * is the same as the one trained by R's glmnet package. The following instruction + * describes how to reproduce the data in R. 
+ * + * import org.apache.spark.mllib.util.LinearDataGenerator + * val data = + * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2) + * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path") + */ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) dataset = sqlContext.createDataFrame( - sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2)) + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) } - test("linear regression: default params") { - val lr = new LinearRegression - assert(lr.getLabelCol == "label") - val model = lr.fit(dataset) - model.transform(dataset) - .select("label", "prediction") - .collect() - // Check defaults - assert(model.getFeaturesCol == "features") - assert(model.getPredictionCol == "prediction") + test("linear regression with intercept without regularization") { + val trainer = new LinearRegression + val model = trainer.fit(dataset) + + /** + * Using the following R code to load the data and train the model using glmnet package. + * + * library("glmnet") + * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + * label <- as.numeric(data$V1) + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.300528 + * as.numeric.data.V2. 4.701024 + * as.numeric.data.V3. 7.198257 + */ + val interceptR = 6.298698 + val weightsR = Array(4.700706, 7.199082) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with L1 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.311546 + * as.numeric.data.V2. 2.123522 + * as.numeric.data.V3. 4.605651 + */ + val interceptR = 6.243000 + val weightsR = Array(4.024821, 6.679841) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } - test("linear regression with setters") { - // Set params, train, and check as many as we can. - val lr = new LinearRegression() - .setMaxIter(10) - .setRegParam(1.0) - val model = lr.fit(dataset) - assert(model.fittingParamMap.get(lr.maxIter).get === 10) - assert(model.fittingParamMap.get(lr.regParam).get === 1.0) - - // Call fit() with new params, and check as many as we can. 
- val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred") - assert(model2.fittingParamMap.get(lr.maxIter).get === 5) - assert(model2.fittingParamMap.get(lr.regParam).get === 0.1) - assert(model2.getPredictionCol == "thePred") + test("linear regression with intercept with L2 regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.328062 + * as.numeric.data.V2. 3.222034 + * as.numeric.data.V3. 4.926260 + */ + val interceptR = 5.269376 + val weightsR = Array(3.736216, 5.712356) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("linear regression with intercept with ElasticNet regularization") { + val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) + val model = trainer.fit(dataset) + + /** + * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + * > weights + * 3 x 1 sparse Matrix of class "dgCMatrix" + * s0 + * (Intercept) 6.324108 + * as.numeric.data.V2. 3.168435 + * as.numeric.data.V3. 5.200403 + */ + val interceptR = 5.696056 + val weightsR = Array(3.670489, 6.001122) + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.weights(0) ~== weightsR(0) relTol 1E-3) + assert(model.weights(1) ~== weightsR(1) relTol 1E-3) + + model.transform(dataset).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = + features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept + assert(prediction1 ~== prediction2 relTol 1E-5) + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index cc747dabb9968..41ec794146c69 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -68,7 +68,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { .setSeed(12345) val corpus = sc.parallelize(tinyCorpus, 2) - val model: DistributedLDAModel = lda.run(corpus) + val model: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] // Check: basic parameters val localModel = model.toLocal diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 2839c4c289b2d..24755e9ff46fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -270,4 +270,48 @@ class VectorsSuite extends FunSuite { assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) => a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8) } + + test("Vector numActive and numNonzeros") { + val dv = Vectors.dense(0.0, 2.0, 3.0, 
0.0) + assert(dv.numActives === 4) + assert(dv.numNonzeros === 2) + + val sv = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv.numActives === 3) + assert(sv.numNonzeros === 2) + } + + test("Vector toSparse and toDense") { + val dv0 = Vectors.dense(0.0, 2.0, 3.0, 0.0) + assert(dv0.toDense === dv0) + val dv0s = dv0.toSparse + assert(dv0s.numActives === 2) + assert(dv0s === dv0) + + val sv0 = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0)) + assert(sv0.toDense === sv0) + val sv0s = sv0.toSparse + assert(sv0s.numActives === 2) + assert(sv0s === sv0) + } + + test("Vector.compressed") { + val dv0 = Vectors.dense(1.0, 2.0, 3.0, 0.0) + val dv0c = dv0.compressed.asInstanceOf[DenseVector] + assert(dv0c === dv0) + + val dv1 = Vectors.dense(0.0, 2.0, 0.0, 0.0) + val dv1c = dv1.compressed.asInstanceOf[SparseVector] + assert(dv1 === dv1c) + assert(dv1c.numActives === 1) + + val sv0 = Vectors.sparse(4, Array(1, 2), Array(2.0, 0.0)) + val sv0c = sv0.compressed.asInstanceOf[SparseVector] + assert(sv0 === sv0c) + assert(sv0c.numActives === 1) + + val sv1 = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) + val sv1c = sv1.compressed.asInstanceOf[DenseVector] + assert(sv1 === sv1c) + } } diff --git a/network/common/pom.xml b/network/common/pom.xml index 22c738bde6d42..0c3147761cfc5 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -95,7 +95,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 test-jar-on-test-compile diff --git a/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java new file mode 100644 index 0000000000000..36d655017fb0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ByteUnit.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.util; + +public enum ByteUnit { + BYTE (1), + KiB (1024L), + MiB ((long) Math.pow(1024L, 2L)), + GiB ((long) Math.pow(1024L, 3L)), + TiB ((long) Math.pow(1024L, 4L)), + PiB ((long) Math.pow(1024L, 5L)); + + private ByteUnit(long multiplier) { + this.multiplier = multiplier; + } + + // Interpret the provided number (d) with suffix (u) as this unit type. + // E.g. KiB.interpret(1, MiB) interprets 1MiB as its KiB representation = 1024k + public long convertFrom(long d, ByteUnit u) { + return u.convertTo(d, this); + } + + // Convert the provided number (d) interpreted as this unit type to unit type (u). + public long convertTo(long d, ByteUnit u) { + if (multiplier > u.multiplier) { + long ratio = multiplier / u.multiplier; + if (Long.MAX_VALUE / ratio < d) { + throw new IllegalArgumentException("Conversion of " + d + " exceeds Long.MAX_VALUE in " + + name() + ". Try a larger unit (e.g. 
MiB instead of KiB)"); + } + return d * ratio; + } else { + // Perform operations in this order to avoid potential overflow + // when computing d * multiplier + return d / (u.multiplier / multiplier); + } + } + + public double toBytes(long d) { + if (d < 0) { + throw new IllegalArgumentException("Negative size value. Size must be positive: " + d); + } + return d * multiplier; + } + + public long toKiB(long d) { return convertTo(d, KiB); } + public long toMiB(long d) { return convertTo(d, MiB); } + public long toGiB(long d) { return convertTo(d, GiB); } + public long toTiB(long d) { return convertTo(d, TiB); } + public long toPiB(long d) { return convertTo(d, PiB); } + + private final long multiplier; +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index b6fbace509a0e..6b514aaa1290d 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -126,7 +126,7 @@ private static boolean isSymlink(File file) throws IOException { return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); } - private static ImmutableMap timeSuffixes = + private static final ImmutableMap timeSuffixes = ImmutableMap.builder() .put("us", TimeUnit.MICROSECONDS) .put("ms", TimeUnit.MILLISECONDS) @@ -137,6 +137,21 @@ private static boolean isSymlink(File file) throws IOException { .put("d", TimeUnit.DAYS) .build(); + private static final ImmutableMap byteSuffixes = + ImmutableMap.builder() + .put("b", ByteUnit.BYTE) + .put("k", ByteUnit.KiB) + .put("kb", ByteUnit.KiB) + .put("m", ByteUnit.MiB) + .put("mb", ByteUnit.MiB) + .put("g", ByteUnit.GiB) + .put("gb", ByteUnit.GiB) + .put("t", ByteUnit.TiB) + .put("tb", ByteUnit.TiB) + .put("p", ByteUnit.PiB) + .put("pb", ByteUnit.PiB) + .build(); + /** * Convert a passed time string (e.g. 50s, 100ms, or 250us) to a time count for * internal use. If no suffix is provided a direct conversion is attempted. @@ -145,16 +160,14 @@ private static long parseTimeString(String str, TimeUnit unit) { String lower = str.toLowerCase().trim(); try { - String suffix; - long val; Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); - if (m.matches()) { - val = Long.parseLong(m.group(1)); - suffix = m.group(2); - } else { + if (!m.matches()) { throw new NumberFormatException("Failed to parse time string: " + str); } + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + // Check for invalid suffixes if (suffix != null && !timeSuffixes.containsKey(suffix)) { throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); @@ -164,7 +177,7 @@ private static long parseTimeString(String str, TimeUnit unit) { return unit.convert(val, suffix != null ? timeSuffixes.get(suffix) : unit); } catch (NumberFormatException e) { String timeError = "Time must be specified as seconds (s), " + - "milliseconds (ms), microseconds (us), minutes (m or min) hour (h), or day (d). " + + "milliseconds (ms), microseconds (us), minutes (m or min), hour (h), or day (d). " + "E.g. 50s, 100ms, or 250us."; throw new NumberFormatException(timeError + "\n" + e.getMessage()); @@ -186,5 +199,83 @@ public static long timeStringAsMs(String str) { public static long timeStringAsSec(String str) { return parseTimeString(str, TimeUnit.SECONDS); } + + /** + * Convert a passed byte string (e.g. 
50b, 100kb, or 250mb) to a ByteUnit for + * internal use. If no suffix is provided a direct conversion of the provided default is + * attempted. + */ + private static long parseByteString(String str, ByteUnit unit) { + String lower = str.toLowerCase().trim(); + + try { + Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); + Matcher fractionMatcher = Pattern.compile("([0-9]+\\.[0-9]+)([a-z]+)?").matcher(lower); + + if (m.matches()) { + long val = Long.parseLong(m.group(1)); + String suffix = m.group(2); + + // Check for invalid suffixes + if (suffix != null && !byteSuffixes.containsKey(suffix)) { + throw new NumberFormatException("Invalid suffix: \"" + suffix + "\""); + } + + // If suffix is valid use that, otherwise none was provided and use the default passed + return unit.convertFrom(val, suffix != null ? byteSuffixes.get(suffix) : unit); + } else if (fractionMatcher.matches()) { + throw new NumberFormatException("Fractional values are not supported. Input was: " + + fractionMatcher.group(1)); + } else { + throw new NumberFormatException("Failed to parse byte string: " + str); + } + + } catch (NumberFormatException e) { + String timeError = "Size must be specified as bytes (b), " + + "kibibytes (k), mebibytes (m), gibibytes (g), tebibytes (t), or pebibytes(p). " + + "E.g. 50b, 100k, or 250m."; + throw new NumberFormatException(timeError + "\n" + e.getMessage()); + } + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to bytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in bytes. + */ + public static long byteStringAsBytes(String str) { + return parseByteString(str, ByteUnit.BYTE); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to kibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in kibibytes. + */ + public static long byteStringAsKb(String str) { + return parseByteString(str, ByteUnit.KiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to mebibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in mebibytes. + */ + public static long byteStringAsMb(String str) { + return parseByteString(str, ByteUnit.MiB); + } + + /** + * Convert a passed byte string (e.g. 50b, 100k, or 250m) to gibibytes for + * internal use. + * + * If no suffix is provided, the passed number is assumed to be in gibibytes. 
+ */ + public static long byteStringAsGb(String str) { + return parseByteString(str, ByteUnit.GiB); + } } diff --git a/pom.xml b/pom.xml index 9fbce1d639d8b..c85c5feeaf383 100644 --- a/pom.xml +++ b/pom.xml @@ -97,6 +97,7 @@ sql/catalyst sql/core sql/hive + unsafe assembly external/twitter external/flume @@ -1082,7 +1083,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.3.1 + 1.4 enforce-versions @@ -1105,7 +1106,7 @@ org.codehaus.mojo build-helper-maven-plugin - 1.8 + 1.9.1 net.alchim31.maven @@ -1176,7 +1177,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.1 + 3.3 ${java.version} ${java.version} @@ -1189,7 +1190,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.18 + 2.18.1 @@ -1215,6 +1216,7 @@ false false true + true false @@ -1260,17 +1262,17 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 + 2.6 org.apache.maven.plugins maven-antrun-plugin - 1.7 + 1.8 org.apache.maven.plugins maven-source-plugin - 2.2.1 + 2.4 true @@ -1287,7 +1289,7 @@ org.apache.maven.plugins maven-clean-plugin - 2.5 + 2.6.1 @@ -1305,7 +1307,27 @@ org.apache.maven.plugins maven-javadoc-plugin - 2.10.1 + 2.10.3 + + + org.codehaus.mojo + exec-maven-plugin + 1.4.0 + + + org.apache.maven.plugins + maven-assembly-plugin + 2.5.3 + + + org.apache.maven.plugins + maven-install-plugin + 2.5.2 + + + org.apache.maven.plugins + maven-deploy-plugin + 2.8.2 @@ -1315,7 +1337,7 @@ org.apache.maven.plugins maven-dependency-plugin - 2.9 + 2.10 test-compile @@ -1334,7 +1356,7 @@ org.codehaus.gmavenplus gmavenplus-plugin - 1.2 + 1.5 process-test-classes @@ -1359,7 +1381,7 @@ org.apache.maven.plugins maven-shade-plugin - 2.2 + 2.3 false diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7ef363a2f07ad..3beafa158eb97 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -72,6 +72,22 @@ object MimaExcludes { // SPARK-6703 Add getOrCreate method to SparkContext ProblemFilters.exclude[IncompatibleResultTypeProblem] ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext") + )++ Seq( + // SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.mllib.clustering.LDA$EMOptimizer") + ) ++ Seq( + // SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.compressed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toDense"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numNonzeros"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.toSparse"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Vector.numActives") ) case v if v.startsWith("1.3") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 09b4976d10c26..b7dbcd9bc562a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,11 +34,11 @@ object BuildCommons { val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher) = + streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", 
"streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _)) + "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", @@ -159,7 +159,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks // TODO: remove launcher from this list after 1.3. allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } @@ -496,6 +496,7 @@ object TestSettings { javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", + javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index cc9a4cf8ba170..a57c0b3ae0d00 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -39,7 +39,8 @@ IntegerType, ByteType -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', + 'Matrix', 'DenseMatrix', 'SparseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4759f5fe783ad..d9cbbc68b3bf0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -237,7 +237,8 @@ def explain(self, extended=False): :param extended: boolean, default ``False``. If ``False``, prints only the physical plan. >>> df.explain() - PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:... + PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at applySchemaToPythonRDD at\ + NativeMethodAccessorImpl.java:... >>> df.explain(True) == Parsed Logical Plan == @@ -632,7 +633,8 @@ def __getattr__(self, name): [Row(age=2), Row(age=5)] """ if name not in self.columns: - raise AttributeError("No such column: %s" % name) + raise AttributeError( + "'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) jc = self._jdf.apply(name) return Column(jc) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index bb47923f24b82..555c2fa5e7071 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -54,7 +54,7 @@ def _(col): 'upper': 'Converts a string expression to upper case.', 'lower': 'Converts a string expression to upper case.', 'sqrt': 'Computes the square root of the specified float value.', - 'abs': 'Computes the absolutle value.', + 'abs': 'Computes the absolute value.', 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', @@ -75,6 +75,20 @@ def _(col): __all__.sort() +def approxCountDistinct(col, rsd=None): + """Returns a new :class:`Column` for approximate distinct count of ``col``. 
+ + >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() + [Row(c=2)] + """ + sc = SparkContext._active_spark_context + if rsd is None: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) + else: + jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) + return Column(jc) + + def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. @@ -89,18 +103,36 @@ def countDistinct(col, *cols): return Column(jc) -def approxCountDistinct(col, rsd=None): - """Returns a new :class:`Column` for approximate distinct count of ``col``. +def monotonicallyIncreasingId(): + """A column that generates monotonically increasing 64-bit integers. - >>> df.agg(approxCountDistinct(df.age).alias('c')).collect() - [Row(c=2)] + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the record number + within each partition in the lower 33 bits. The assumption is that the data frame has + less than 1 billion partitions, and each partition has less than 8 billion records. + + As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + This expression would return the following IDs: + 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + + >>> df0 = sc.parallelize(range(2), 2).mapPartitions(lambda x: [(1,), (2,), (3,)]).toDF(['col1']) + >>> df0.select(monotonicallyIncreasingId().alias('id')).collect() + [Row(id=0), Row(id=1), Row(id=2), Row(id=8589934592), Row(id=8589934593), Row(id=8589934594)] """ sc = SparkContext._active_spark_context - if rsd is None: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col)) - else: - jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd) - return Column(jc) + return Column(sc._jvm.functions.monotonicallyIncreasingId()) + + +def sparkPartitionId(): + """A column for partition ID of the Spark task. + + Note that this is indeterministic because it depends on data partitioning and task scheduling. + + >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect() + [Row(pid=0), Row(pid=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.sparkPartitionId()) class UserDefinedFunction(object): diff --git a/python/pyspark/sql/mathfunctions.py b/python/pyspark/sql/mathfunctions.py new file mode 100644 index 0000000000000..7dbcab8694293 --- /dev/null +++ b/python/pyspark/sql/mathfunctions.py @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +""" +A collection of builtin math functions +""" + +from pyspark import SparkContext +from pyspark.sql.dataframe import Column + +__all__ = [] + + +def _create_unary_mathfunction(name, doc=""): + """ Create a unary mathfunction by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.mathfunctions, name)(col._jc if isinstance(col, Column) else col) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +def _create_binary_mathfunction(name, doc=""): + """ Create a binary mathfunction by name""" + def _(col1, col2): + sc = SparkContext._active_spark_context + # users might write ints for simplicity. This would throw an error on the JVM side. + if type(col1) is int: + col1 = col1 * 1.0 + if type(col2) is int: + col2 = col2 * 1.0 + jc = getattr(sc._jvm.mathfunctions, name)(col1._jc if isinstance(col1, Column) else col1, + col2._jc if isinstance(col2, Column) else col2) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + +# math functions are found under another object therefore, they need to be handled separately +_mathfunctions = { + 'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' + + '0.0 through pi.', + 'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' + + '-pi/2 through pi/2.', + 'atan': 'Computes the tangent inverse of the given value.', + 'cbrt': 'Computes the cube-root of the given value.', + 'ceil': 'Computes the ceiling of the given value.', + 'cos': 'Computes the cosine of the given value.', + 'cosh': 'Computes the hyperbolic cosine of the given value.', + 'exp': 'Computes the exponential of the given value.', + 'expm1': 'Computes the exponential of the given value minus one.', + 'floor': 'Computes the floor of the given value.', + 'log': 'Computes the natural logarithm of the given value.', + 'log10': 'Computes the logarithm of the given value in Base 10.', + 'log1p': 'Computes the natural logarithm of the given value plus one.', + 'rint': 'Returns the double value that is closest in value to the argument and' + + ' is equal to a mathematical integer.', + 'signum': 'Computes the signum of the given value.', + 'sin': 'Computes the sine of the given value.', + 'sinh': 'Computes the hyperbolic sine of the given value.', + 'tan': 'Computes the tangent of the given value.', + 'tanh': 'Computes the hyperbolic tangent of the given value.', + 'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' + + 'measured in degrees.', + 'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' + + 'measured in radians.' +} + +# math functions that take two arguments as input +_binary_mathfunctions = { + 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + + 'polar coordinates (r, theta).', + 'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.', + 'pow': 'Returns the value of the first argument raised to the power of the second argument.' 
+} + +for _name, _doc in _mathfunctions.items(): + globals()[_name] = _create_unary_mathfunction(_name, _doc) +for _name, _doc in _binary_mathfunctions.items(): + globals()[_name] = _create_binary_mathfunction(_name, _doc) +del _name, _doc +__all__ += _mathfunctions.keys() +__all__ += _binary_mathfunctions.keys() +__all__.sort() diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fe43c374f1cb1..2ffd18ebd7c89 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -387,6 +387,35 @@ def test_aggregator(self): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_math_functions(self): + df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() + from pyspark.sql import mathfunctions as functions + import math + + def get_values(l): + return [j[0] for j in l] + + def assert_close(a, b): + c = get_values(b) + diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)] + return sum(diff) == len(a) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos(df.a)).collect()) + assert_close([math.cos(i) for i in range(10)], + df.select(functions.cos("a")).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df.a)).collect()) + assert_close([math.sin(i) for i in range(10)], + df.select(functions.sin(df['a'])).collect()) + assert_close([math.pow(i, 2 * i) for i in range(10)], + df.select(functions.pow(df.a, df.b)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2)).collect()) + assert_close([math.pow(i, 2) for i in range(10)], + df.select(functions.pow(df.a, 2.0)).collect()) + assert_close([math.hypot(i, 2 * i) for i in range(10)], + df.select(functions.hypot(df.a, df.b)).collect()) + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 8d610d6569b4a..e278b29003f69 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -17,11 +17,12 @@ from py4j.java_gateway import Py4JJavaError +from pyspark.rdd import RDD from pyspark.storagelevel import StorageLevel from pyspark.serializers import PairDeserializer, NoOpSerializer from pyspark.streaming import DStream -__all__ = ['KafkaUtils', 'utf8_decoder'] +__all__ = ['Broker', 'KafkaUtils', 'OffsetRange', 'TopicAndPartition', 'utf8_decoder'] def utf8_decoder(s): @@ -67,7 +68,104 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, except Py4JJavaError as e: # TODO: use --jar once it also work on driver if 'ClassNotFoundException' in str(e.java_exception): - print(""" + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createDirectStream(ssc, topics, kafkaParams, fromOffsets={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create an input stream that directly pulls messages from a Kafka Broker and specific offset. + + This is not a receiver based Kafka input stream, it directly pulls the message from Kafka + in each batch duration and processed without storing. + + This does not use Zookeeper to store offsets. The consumed offsets are tracked + by the stream itself. 
For interoperability with Kafka monitoring tools that depend on + Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application. + You can access the offsets used in each batch from the generated RDDs (see + + To recover from driver failures, you have to enable checkpointing in the StreamingContext. + The information on consumed offset can be recovered from the checkpoint. + See the programming guide for details (constraints, etc.). + + :param ssc: StreamingContext object. + :param topics: list of topic_name to consume. + :param kafkaParams: Additional params for Kafka. + :param fromOffsets: Per-topic/partition Kafka offsets defining the (inclusive) starting + point of the stream. + :param keyDecoder: A function used to decode key (default is utf8_decoder). + :param valueDecoder: A function used to decode value (default is utf8_decoder). + :return: A DStream object + """ + if not isinstance(topics, list): + raise TypeError("topics should be list") + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + + jfromOffsets = dict([(k._jTopicAndPartition(helper), + v) for (k, v) in fromOffsets.items()]) + jstream = helper.createDirectStream(ssc._jssc, kafkaParams, set(topics), jfromOffsets) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(ssc.sparkContext) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def createRDD(sc, kafkaParams, offsetRanges, leaders={}, + keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): + """ + .. note:: Experimental + + Create a RDD from Kafka using offset ranges for each topic and partition. + :param sc: SparkContext object + :param kafkaParams: Additional params for Kafka + :param offsetRanges: list of offsetRange to specify topic:partition:[start, end) to consume + :param leaders: Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty + map, in which case leaders will be looked up on the driver. 
+ :param keyDecoder: A function used to decode key (default is utf8_decoder) + :param valueDecoder: A function used to decode value (default is utf8_decoder) + :return: A RDD object + """ + if not isinstance(kafkaParams, dict): + raise TypeError("kafkaParams should be dict") + if not isinstance(offsetRanges, list): + raise TypeError("offsetRanges should be list") + + try: + helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper") + helper = helperClass.newInstance() + joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges] + jleaders = dict([(k._jTopicAndPartition(helper), + v._jBroker(helper)) for (k, v) in leaders.items()]) + jrdd = helper.createRDD(sc._jsc, kafkaParams, joffsetRanges, jleaders) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KafkaUtils._printErrorMsg(sc) + raise e + + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + rdd = RDD(jrdd, sc, ser) + return rdd.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) + + @staticmethod + def _printErrorMsg(sc): + print(""" ________________________________________________________________________________________________ Spark Streaming's Kafka libraries not found in class path. Try one of the following. @@ -85,8 +183,63 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={}, ________________________________________________________________________________________________ -""" % (ssc.sparkContext.version, ssc.sparkContext.version)) - raise e - ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) - stream = DStream(jstream, ssc, ser) - return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1]))) +""" % (sc.version, sc.version)) + + +class OffsetRange(object): + """ + Represents a range of offsets from a single Kafka TopicAndPartition. + """ + + def __init__(self, topic, partition, fromOffset, untilOffset): + """ + Create a OffsetRange to represent range of offsets + :param topic: Kafka topic name. + :param partition: Kafka partition id. + :param fromOffset: Inclusive starting offset. + :param untilOffset: Exclusive ending offset. + """ + self._topic = topic + self._partition = partition + self._fromOffset = fromOffset + self._untilOffset = untilOffset + + def _jOffsetRange(self, helper): + return helper.createOffsetRange(self._topic, self._partition, self._fromOffset, + self._untilOffset) + + +class TopicAndPartition(object): + """ + Represents a specific top and partition for Kafka. + """ + + def __init__(self, topic, partition): + """ + Create a Python TopicAndPartition to map to the Java related object + :param topic: Kafka topic name. + :param partition: Kafka partition id. + """ + self._topic = topic + self._partition = partition + + def _jTopicAndPartition(self, helper): + return helper.createTopicAndPartition(self._topic, self._partition) + + +class Broker(object): + """ + Represent the host and port info for a Kafka broker. + """ + + def __init__(self, host, port): + """ + Create a Python Broker to map to the Java related object. + :param host: Broker's hostname. + :param port: Broker's port. 
+ """ + self._host = host + self._port = port + + def _jBroker(self, helper): + return helper.createBroker(self._host, self._port) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5fa1e5ef081ab..7c06c203455d9 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -21,6 +21,7 @@ import time import operator import tempfile +import random import struct from functools import reduce @@ -35,7 +36,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext -from pyspark.streaming.kafka import KafkaUtils +from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition class PySparkStreamingTestCase(unittest.TestCase): @@ -590,9 +591,27 @@ def tearDown(self): super(KafkaStreamTests, self).tearDown() + def _randomTopic(self): + return "topic-%d" % random.randint(0, 10000) + + def _validateStreamResult(self, sendData, stream): + result = {} + for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), + sum(sendData.values()))): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + + def _validateRddResult(self, sendData, rdd): + result = {} + for i in rdd.map(lambda x: x[1]).collect(): + result[i] = result.get(i, 0) + 1 + + self.assertEqual(sendData, result) + def test_kafka_stream(self): """Test the Python Kafka stream API.""" - topic = "topic1" + topic = self._randomTopic() sendData = {"a": 3, "b": 5, "c": 10} self._kafkaTestUtils.createTopic(topic) @@ -601,13 +620,64 @@ def test_kafka_stream(self): stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(), "test-streaming-consumer", {topic: 1}, {"auto.offset.reset": "smallest"}) + self._validateStreamResult(sendData, stream) - result = {} - for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]), - sum(sendData.values()))): - result[i] = result.get(i, 0) + 1 + def test_kafka_direct_stream(self): + """Test the Python direct Kafka stream API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(), + "auto.offset.reset": "smallest"} - self.assertEqual(sendData, result) + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_direct_stream_from_offset(self): + """Test the Python direct Kafka stream API with start offset specified.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + fromOffsets = {TopicAndPartition(topic, 0): long(0)} + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets) + self._validateStreamResult(sendData, stream) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd(self): + """Test the Python direct Kafka RDD API.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, 
sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges) + self._validateRddResult(sendData, rdd) + + @unittest.skipIf(sys.version >= "3", "long type not support") + def test_kafka_rdd_with_leaders(self): + """Test the Python direct Kafka RDD API with leaders.""" + topic = self._randomTopic() + sendData = {"a": 1, "b": 2, "c": 3} + offsetRanges = [OffsetRange(topic, 0, long(0), long(sum(sendData.values())))] + kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress()} + address = self._kafkaTestUtils.brokerAddress().split(":") + leaders = {TopicAndPartition(topic, 0): Broker(address[0], int(address[1]))} + + self._kafkaTestUtils.createTopic(topic) + self._kafkaTestUtils.sendMessages(topic, sendData) + + rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) + self._validateRddResult(sendData, rdd) if __name__ == "__main__": unittest.main() diff --git a/sbin/start-mesos-dispatcher.sh b/sbin/start-mesos-dispatcher.sh new file mode 100755 index 0000000000000..ef1fc573d5c65 --- /dev/null +++ b/sbin/start-mesos-dispatcher.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Starts the Mesos Cluster Dispatcher on the machine this script is executed on. +# The Mesos Cluster Dispatcher is responsible for launching the Mesos framework and +# Rest server to handle driver requests for Mesos cluster mode. +# Only one cluster dispatcher is needed per Mesos cluster. + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" + +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +if [ "$SPARK_MESOS_DISPATCHER_PORT" = "" ]; then + SPARK_MESOS_DISPATCHER_PORT=7077 +fi + +if [ "$SPARK_MESOS_DISPATCHER_HOST" = "" ]; then + SPARK_MESOS_DISPATCHER_HOST=`hostname` +fi + + +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 --host $SPARK_MESOS_DISPATCHER_HOST --port $SPARK_MESOS_DISPATCHER_PORT "$@" diff --git a/sbin/start-shuffle-service.sh b/sbin/start-shuffle-service.sh new file mode 100755 index 0000000000000..4fddcf7f95d40 --- /dev/null +++ b/sbin/start-shuffle-service.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the external shuffle server on the machine this script is executed on. +# +# Usage: start-shuffle-server.sh +# +# Use the SPARK_SHUFFLE_OPTS environment variable to set shuffle server configuration. +# + +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +. "$sbin/spark-config.sh" +. "$SPARK_PREFIX/bin/load-spark-env.sh" + +exec "$sbin"/spark-daemon.sh start org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sbin/stop-mesos-dispatcher.sh b/sbin/stop-mesos-dispatcher.sh new file mode 100755 index 0000000000000..cb65d95b5e524 --- /dev/null +++ b/sbin/stop-mesos-dispatcher.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Stop the Mesos Cluster dispatcher on the machine this script is executed on. + +sbin=`dirname "$0"` +sbin=`cd "$sbin"; pwd` + +. "$sbin/spark-config.sh" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.mesos.MesosClusterDispatcher 1 + diff --git a/sbin/stop-shuffle-service.sh b/sbin/stop-shuffle-service.sh new file mode 100755 index 0000000000000..4cb6891ae27fa --- /dev/null +++ b/sbin/stop-shuffle-service.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stops the external shuffle service on the machine this script is executed on. 
+ +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.ExternalShuffleService 1 diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 3dea2ee76542f..5c322d032d474 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -50,6 +50,11 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java new file mode 100644 index 0000000000000..299ff3728a6d9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.util.Arrays; +import java.util.Iterator; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. + * + * This map supports a maximum of 2 billion keys. + */ +public final class UnsafeFixedWidthAggregationMap { + + /** + * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the + * map, we copy this buffer and use it as the value. + */ + private final long[] emptyAggregationBuffer; + + private final StructType aggregationBufferSchema; + + private final StructType groupingKeySchema; + + /** + * Encodes grouping keys as UnsafeRows. + */ + private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + + /** + * A hashmap which maps from opaque bytearray keys to bytearray values. + */ + private final BytesToBytesMap map; + + /** + * Re-used pointer to the current aggregation buffer + */ + private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + + /** + * Scratch space that is used when encoding grouping keys into UnsafeRow format. + * + * By default, this is a 1MB array, but it will grow as necessary in case larger keys are + * encountered. + */ + private long[] groupingKeyConversionScratchSpace = new long[1024 / 8]; + + private final boolean enablePerfMetrics; + + /** + * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, + * false otherwise. 
+ */ + public static boolean supportsGroupKeySchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given + * schema, false otherwise. + */ + public static boolean supportsAggregationBufferSchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * Create a new UnsafeFixedWidthAggregationMap. + * + * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) + * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. + * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param memoryManager the memory manager used to allocate our Unsafe memory structures. + * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) + */ + public UnsafeFixedWidthAggregationMap( + Row emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this.emptyAggregationBuffer = + convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); + this.groupingKeySchema = groupingKeySchema; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.enablePerfMetrics = enablePerfMetrics; + } + + /** + * Convert a Java object row into an UnsafeRow, allocating it into a new long array. + */ + private static long[] convertToUnsafeRow(Row javaRow, StructType schema) { + final UnsafeRowConverter converter = new UnsafeRowConverter(schema); + final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)]; + final long writtenLength = + converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET); + assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; + return unsafeRow; + } + + /** + * Return the aggregation buffer for the current group. For efficiency, all calls to this method + * return the same object. + */ + public UnsafeRow getAggregationBuffer(Row groupingKey) { + final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + // Make sure that the buffer is large enough to hold the key. If it's not, grow it: + if (groupingKeySize > groupingKeyConversionScratchSpace.length) { + // This new array will be initially zero, so there's no need to zero it out here + groupingKeyConversionScratchSpace = new long[groupingKeySize]; + } else { + // Zero out the buffer that's used to hold the current row. This is necessary in order + // to ensure that rows hash properly, since garbage data from the previous row could + // otherwise end up as padding in this row. As a performance optimization, we only zero out + // the portion of the buffer that we'll actually write to. 
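+        // (java.util.Arrays.fill's toIndex argument is exclusive, so only the first
+        // `groupingKeySize` words are cleared here.)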
+ Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0); + } + final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + groupingKey, + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET); + assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; + + // Probe our map using the serialized key + final BytesToBytesMap.Location loc = map.lookup( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize); + if (!loc.isDefined()) { + // This is the first time that we've seen this grouping key, so we'll insert a copy of the + // empty aggregation buffer into the map: + loc.putNewKey( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize, + emptyAggregationBuffer, + PlatformDependent.LONG_ARRAY_OFFSET, + emptyAggregationBuffer.length + ); + } + + // Reset the pointer to point to the value that we just stored or looked up: + final MemoryLocation address = loc.getValueAddress(); + currentAggregationBuffer.pointTo( + address.getBaseObject(), + address.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return currentAggregationBuffer; + } + + /** + * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}. + */ + public static class MapEntry { + private MapEntry() { }; + public final UnsafeRow key = new UnsafeRow(); + public final UnsafeRow value = new UnsafeRow(); + } + + /** + * Returns an iterator over the keys and values in this map. + * + * For efficiency, each call returns the same object. + */ + public Iterator iterator() { + return new Iterator() { + + private final MapEntry entry = new MapEntry(); + private final Iterator mapLocationIterator = map.iterator(); + + @Override + public boolean hasNext() { + return mapLocationIterator.hasNext(); + } + + @Override + public MapEntry next() { + final BytesToBytesMap.Location loc = mapLocationIterator.next(); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + entry.key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + groupingKeySchema + ); + entry.value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return entry; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Free the unsafe memory associated with this map. 
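+   * Note that rows returned by getAggregationBuffer() or iterator() point into this memory and
+   * must not be used after free() has been called.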
+ */ + public void free() { + map.free(); + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + public void printPerfMetrics() { + if (!enablePerfMetrics) { + throw new IllegalStateException("Perf metrics not enabled"); + } + System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); + System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); + System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); + System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + } + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java new file mode 100644 index 0000000000000..0a358ed408aa1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import scala.collection.Map; +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.sql.Date; +import java.util.*; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.DataType; +import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UTF8String; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.bitset.BitSetMethods; + +/** + * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. + * + * Each tuple has three parts: [null bit set] [values] [variable length portion] + * + * The bit set is used for null tracking and is aligned to 8-byte word boundaries. It stores + * one bit per field. + * + * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length + * primitive types, such as long, double, or int, we store the value directly in the word. For + * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the + * base address of the row) that points to the beginning of the variable-length field. + * + * Instances of `UnsafeRow` act as pointers to row data stored in this format. 
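+ *
+ * As an illustration (not a normative spec), a two-field row of (long, string) with no null
+ * values occupies one 8-byte null-bitset word, one word holding the long, one word holding the
+ * string's relative offset, and then the string's 8-byte length followed by its bytes (rounded
+ * up to a whole word) in the variable-length region.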
+ */ +public final class UnsafeRow implements MutableRow { + + private Object baseObject; + private long baseOffset; + + Object getBaseObject() { return baseObject; } + long getBaseOffset() { return baseOffset; } + + /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ + private int numFields; + + /** The width of the null tracking bit set, in bytes */ + private int bitSetWidthInBytes; + /** + * This optional schema is required if you want to call generic get() and set() methods on + * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() + * methods. This should be removed after the planned InternalRow / Row split; right now, it's only + * needed by the generic get() method, which is only called internally by code that accesses + * UTF8String-typed columns. + */ + @Nullable + private StructType schema; + + private long getFieldOffset(int ordinal) { + return baseOffset + bitSetWidthInBytes + ordinal * 8L; + } + + public static int calculateBitSetWidthInBytes(int numFields) { + return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + } + + /** + * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) + */ + public static final Set settableFieldTypes; + + /** + * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). + */ + public static final Set readableFieldTypes; + + static { + settableFieldTypes = Collections.unmodifiableSet( + new HashSet( + Arrays.asList(new DataType[] { + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType + }))); + + // We support get() on a superset of the types for which we support set(): + final Set _readableFieldTypes = new HashSet( + Arrays.asList(new DataType[]{ + StringType + })); + _readableFieldTypes.addAll(settableFieldTypes); + readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); + } + + /** + * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, + * since the value returned by this constructor is equivalent to a null pointer. + */ + public UnsafeRow() { } + + /** + * Update this UnsafeRow to point to different backing data. + * + * @param baseObject the base object + * @param baseOffset the offset within the base object + * @param numFields the number of fields in this row + * @param schema an optional schema; this is necessary if you want to call generic get() or set() + * methods on this row, but is optional if the caller will only use type-specific + * getTYPE() and setTYPE() methods. + */ + public void pointTo( + Object baseObject, + long baseOffset, + int numFields, + @Nullable StructType schema) { + assert numFields >= 0 : "numFields should >= 0"; + assert schema == null || schema.fields().length == numFields; + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = baseObject; + this.baseOffset = baseOffset; + this.numFields = numFields; + this.schema = schema; + } + + private void assertIndexIsValid(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < numFields : "index (" + index + ") should <= " + numFields; + } + + @Override + public void setNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.set(baseObject, baseOffset, i); + // To preserve row equality, zero out the value when setting the column to null. 
+ // Since this row does does not currently support updates to variable-length values, we don't + // have to worry about zeroing out that data. + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + } + + private void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setLong(int ordinal, long value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setDouble(int ordinal, double value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setShort(int ordinal, short value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setByte(int ordinal, byte value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setFloat(int ordinal, float value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } + + @Override + public int size() { + return numFields; + } + + @Override + public int length() { + return size(); + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Object apply(int i) { + return get(i); + } + + @Override + public Object get(int i) { + assertIndexIsValid(i); + assert (schema != null) : "Schema must be defined when calling generic get() method"; + final DataType dataType = schema.fields()[i].dataType(); + // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic + // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to + // separate the internal and external row interfaces, then internal code can fetch strings via + // a new getUTF8String() method and we'll be able to remove this method. 
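+    // (getUTF8String(i), defined below, reads the 8-byte length stored at the field's relative
+    // offset and copies that many bytes out of the variable-length region.)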
+ if (isNullAt(i)) { + return null; + } else if (dataType == StringType) { + return getUTF8String(i); + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean isNullAt(int i) { + assertIndexIsValid(i); + return BitSetMethods.isSet(baseObject, baseOffset, i); + } + + @Override + public boolean getBoolean(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i)); + } + + @Override + public byte getByte(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i)); + } + + @Override + public short getShort(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i)); + } + + @Override + public int getInt(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i)); + } + + @Override + public long getLong(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + } + + @Override + public float getFloat(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Float.NaN; + } else { + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i)); + } + } + + @Override + public double getDouble(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Float.NaN; + } else { + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i)); + } + } + + public UTF8String getUTF8String(int i) { + assertIndexIsValid(i); + final UTF8String str = new UTF8String(); + final long offsetToStringSize = getLong(i); + final int stringSizeInBytes = + (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); + final byte[] strBytes = new byte[stringSizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + stringSizeInBytes + ); + str.set(strBytes); + return str; + } + + @Override + public String getString(int i) { + return getUTF8String(i).toString(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Seq getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public List getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Map getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public java.util.Map getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean anyNull() { + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + } + + @Override + public Seq toSeq() { + final ArraySeq values = new ArraySeq(numFields); + for (int 
fieldNumber = 0; fieldNumber < numFields; fieldNumber++) { + values.update(fieldNumber, get(fieldNumber)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 35c7f00d4e42a..73c9a1c7afdad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -79,6 +79,7 @@ trait HiveTypeCoercion { CaseWhenCoercion :: Division :: PropagateTypes :: + ExpectedInputConversion :: Nil /** @@ -643,4 +644,22 @@ trait HiveTypeCoercion { } } + /** + * Casts types according to the expected input types for Expressions that have the trait + * `ExpectsInputTypes`. + */ + object ExpectedInputConversion extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes who's children have not been resolved yet. + case e if !e.childrenResolved => e + + case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case (child, actual, expected) => + if (actual == expected) child else Cast(child, expected) + } + e.withNewChildren(newC) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4e3bbc06a5b4c..1d71c1b4b0c7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -109,3 +109,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression { override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException } + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait ExpectsInputTypes { + + def expectedChildTypes: Seq[DataType] + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala new file mode 100644 index 0000000000000..5b2c8572784bd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +/** + * Converts Rows into UnsafeRow format. This class is NOT thread-safe. + * + * @param fieldTypes the data types of the row's columns. + */ +class UnsafeRowConverter(fieldTypes: Array[DataType]) { + + def this(schema: StructType) { + this(schema.fields.map(_.dataType)) + } + + /** Re-used pointer to the unsafe row being written */ + private[this] val unsafeRow = new UnsafeRow() + + /** Functions for encoding each column */ + private[this] val writers: Array[UnsafeColumnWriter] = { + fieldTypes.map(t => UnsafeColumnWriter.forType(t)) + } + + /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ + private[this] val fixedLengthSize: Int = + (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + + /** + * Compute the amount of space, in bytes, required to encode the given row. + */ + def getSizeRequirement(row: Row): Int = { + var fieldNumber = 0 + var variableLengthFieldSize: Int = 0 + while (fieldNumber < writers.length) { + if (!row.isNullAt(fieldNumber)) { + variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber) + } + fieldNumber += 1 + } + fixedLengthSize + variableLengthFieldSize + } + + /** + * Convert the given row into UnsafeRow format. + * + * @param row the row to convert + * @param baseObject the base object of the destination address + * @param baseOffset the base offset of the destination address + * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. + */ + def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + var fieldNumber = 0 + var appendCursor: Int = fixedLengthSize + while (fieldNumber < writers.length) { + if (row.isNullAt(fieldNumber)) { + unsafeRow.setNullAt(fieldNumber) + } else { + appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + } + fieldNumber += 1 + } + appendCursor + } + +} + +/** + * Function for writing a column into an UnsafeRow. + */ +private abstract class UnsafeColumnWriter { + /** + * Write a value into an UnsafeRow. + * + * @param source the row being converted + * @param target a pointer to the converted unsafe row + * @param column the column to write + * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * used for calculating where variable-length data should be written + * @return the number of variable-length bytes written + */ + def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int + + /** + * Return the number of bytes that are needed to write this variable-length value. 
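+   * (Fixed-width primitive writers return 0; the string writer returns the size of its 8-byte
+   * length prefix plus the UTF-8 bytes, rounded up to a whole 8-byte word.)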
+ */ + def getSize(source: Row, column: Int): Int +} + +private object UnsafeColumnWriter { + + def forType(dataType: DataType): UnsafeColumnWriter = { + dataType match { + case NullType => NullUnsafeColumnWriter + case BooleanType => BooleanUnsafeColumnWriter + case ByteType => ByteUnsafeColumnWriter + case ShortType => ShortUnsafeColumnWriter + case IntegerType => IntUnsafeColumnWriter + case LongType => LongUnsafeColumnWriter + case FloatType => FloatUnsafeColumnWriter + case DoubleType => DoubleUnsafeColumnWriter + case StringType => StringUnsafeColumnWriter + case t => + throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + } + } +} + +// ------------------------------------------------------------------------------------------------ + +private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter +private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter +private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter +private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter +private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter +private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter +private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter +private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter +private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter + +private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { + // Primitives don't write to the variable-length region: + def getSize(sourceRow: Row, column: Int): Int = 0 +} + +private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setNullAt(column) + 0 + } +} + +private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setBoolean(column, source.getBoolean(column)) + 0 + } +} + +private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setByte(column, source.getByte(column)) + 0 + } +} + +private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setShort(column, source.getShort(column)) + 0 + } +} + +private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setInt(column, source.getInt(column)) + 0 + } +} + +private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setLong(column, source.getLong(column)) + 0 + } +} + +private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setFloat(column, source.getFloat(column)) + 0 + } +} + +private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setDouble(column, source.getDouble(column)) + 0 
+ } +} + +private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(source: Row, column: Int): Int = { + val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } + + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + val value = source.get(column).asInstanceOf[UTF8String] + val baseObject = target.getBaseObject + val baseOffset = target.getBaseOffset + val numBytes = value.getBytes.length + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + PlatformDependent.copyMemory( + value.getBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + appendCursor + 8, + numBytes + ) + target.setLong(column, appendCursor) + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala new file mode 100644 index 0000000000000..fcc06d3aa1036 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, BinaryExpression, Expression, Row} +import org.apache.spark.sql.types._ + +/** + * A binary expression specifically for math functions that take two `Double`s as input and returns + * a `Double`. + * @param f The math function. + * @param name The short name of the function + */ +abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) + extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + type EvaluatedType = Any + override def symbol: String = null + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def nullable: Boolean = left.nullable || right.nullable + override def toString: String = s"$name($left, $right)" + + override lazy val resolved = + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, + s"datatype. 
Can not resolve due to differing types ${left.dataType}, ${right.dataType}") + } + left.dataType + } + + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } + } +} + +case class Atan2( + left: Expression, + right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { + override def eval(input: Row): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } +} + +case class Hypot( + left: Expression, + right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") + +case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala new file mode 100644 index 0000000000000..dc68469e060cb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.mathfuncs + +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, Row, UnaryExpression} +import org.apache.spark.sql.types._ + +/** + * A unary expression specifically for math functions. Math Functions expect a specific type of + * input format, therefore these functions extend `ExpectsInputTypes`. 
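+ * For example, the analyzer's ExpectedInputConversion rule rewrites Sin(Literal("2")) into
+ * Sin(Cast(Literal("2"), DoubleType)) before evaluation (illustrative).
+ * @param f The math function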
+ * @param name The short name of the function + */ +abstract class MathematicalExpression(f: Double => Double, name: String) + extends UnaryExpression with Serializable with ExpectsInputTypes { + self: Product => + type EvaluatedType = Any + + override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def dataType: DataType = DoubleType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = true + override def toString: String = s"$name($child)" + + override def eval(input: Row): Any = { + val evalE = child.eval(input) + if (evalE == null) { + null + } else { + val result = f(evalE.asInstanceOf[Double]) + if (result.isNaN) null else result + } + } +} + +case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS") + +case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN") + +case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN") + +case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT") + +case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL") + +case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS") + +case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH") + +case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP") + +case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1") + +case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR") + +case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG") + +case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10") + +case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P") + +case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND") + +case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM") + +case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN") + +case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH") + +case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN") + +case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH") + +case class ToDegrees(child: Expression) + extends MathematicalExpression(math.toDegrees, "DEGREES") + +case class ToRadians(child: Expression) + extends MathematicalExpression(math.toRadians, "RADIANS") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c8c643f7d17a..4574934d910db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -92,7 +92,7 @@ object PhysicalOperation extends PredicateHelper { } def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect { - case a @ Alias(child, _) => a.toAttribute.asInstanceOf[Attribute] -> child + case a @ Alias(child, _) => a.toAttribute -> child }.toMap def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index bbc94a7ab3398..608e272da7784 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -310,6 +310,17 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * Return a new RDD that has exactly `numPartitions` partitions. Differs from + * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user + * asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer + * of the output requires some specific ordering or distribution of the data. + */ +case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + /** * A relation with one row. This is used in "SELECT ..." without a from clause. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala index e737418d9c3bc..63df2c1ee72ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/partitioning.scala @@ -32,5 +32,11 @@ abstract class RedistributeData extends UnaryNode { case class SortPartitions(sortExpressions: Seq[SortOrder], child: LogicalPlan) extends RedistributeData -case class Repartition(partitionExpressions: Seq[Expression], child: LogicalPlan) +/** + * This method repartitions data using [[Expression]]s, and receives information about the + * number of partitions during execution. Used when a specific ordering or distribution is + * expected by the consumer of the query result. Use [[Repartition]] for RDD-like + * `coalesce` and `repartition`. + */ +case class RepartitionByExpression(partitionExpressions: Seq[Expression], child: LogicalPlan) extends RedistributeData diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 5163f05879e42..04f3379afb38d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -108,7 +108,7 @@ private[sql] object DataTypeParser { override val lexical = new SqlLexical } - def apply(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) + def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString) } /** The exception thrown from the [[DataTypeParser]]. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 76298f03c94ae..fa71001c9336e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ import org.apache.spark.sql.types._ @@ -1152,6 +1153,158 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { checkEvaluation(c1 ^ c2, 3, row) checkEvaluation(~c1, -2, row) } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + * @param expectNull Whether the given values should return null or not + * @tparam T Generic type for primitives + */ + def unaryMathFunctionEvaluation[@specialized(Int, Double, Float, Long) T]( + c: Expression => Expression, + f: T => T, + domain: Iterable[T] = (-20 to 20).map(_ * 0.1), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { value => + checkEvaluation(c(Literal(value)), null, EmptyRow) + } + } else { + domain.foreach { value => + checkEvaluation(c(Literal(value)), f(value), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("sin") { + unaryMathFunctionEvaluation(Sin, math.sin) + } + + test("asin") { + unaryMathFunctionEvaluation(Asin, math.asin, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Asin, math.asin, (11 to 20).map(_ * 0.1), true) + } + + test("sinh") { + unaryMathFunctionEvaluation(Sinh, math.sinh) + } + + test("cos") { + unaryMathFunctionEvaluation(Cos, math.cos) + } + + test("acos") { + unaryMathFunctionEvaluation(Acos, math.acos, (-10 to 10).map(_ * 0.1)) + unaryMathFunctionEvaluation(Acos, math.acos, (11 to 20).map(_ * 0.1), true) + } + + test("cosh") { + unaryMathFunctionEvaluation(Cosh, math.cosh) + } + + test("tan") { + unaryMathFunctionEvaluation(Tan, math.tan) + } + + test("atan") { + unaryMathFunctionEvaluation(Atan, math.atan) + } + + test("tanh") { + unaryMathFunctionEvaluation(Tanh, math.tanh) + } + + test("toDeg") { + unaryMathFunctionEvaluation(ToDegrees, math.toDegrees) + } + + test("toRad") { + unaryMathFunctionEvaluation(ToRadians, math.toRadians) + } + + test("cbrt") { + unaryMathFunctionEvaluation(Cbrt, math.cbrt) + } + + test("ceil") { + unaryMathFunctionEvaluation(Ceil, math.ceil) + } + + test("floor") { + unaryMathFunctionEvaluation(Floor, math.floor) + } + + test("rint") { + unaryMathFunctionEvaluation(Rint, math.rint) + } + + test("exp") { + unaryMathFunctionEvaluation(Exp, math.exp) + } + + test("expm1") { + unaryMathFunctionEvaluation(Expm1, math.expm1) + } + + test("signum") { + unaryMathFunctionEvaluation[Double](Signum, math.signum) + } + + test("log") { + unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true) + } + + test("log10") { + unaryMathFunctionEvaluation(Log10, math.log10, (0 to 20).map(_ * 0.1)) + 
unaryMathFunctionEvaluation(Log10, math.log10, (-5 to -1).map(_ * 0.1), true) + } + + test("log1p") { + unaryMathFunctionEvaluation(Log1p, math.log1p, (-1 to 20).map(_ * 0.1)) + unaryMathFunctionEvaluation(Log1p, math.log1p, (-10 to -2).map(_ * 1.0), true) + } + + /** + * Used for testing math functions for DataFrames. + * @param c The DataFrame function + * @param f The functions in scala.math + * @param domain The set of values to run the function with + */ + def binaryMathFunctionEvaluation( + c: (Expression, Expression) => Expression, + f: (Double, Double) => Double, + domain: Iterable[(Double, Double)] = (-20 to 20).map(v => (v * 0.1, v * -0.1)), + expectNull: Boolean = false): Unit = { + if (expectNull) { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), null, create_row(null)) + } + } else { + domain.foreach { case (v1, v2) => + checkEvaluation(c(v1, v2), f(v1 + 0.0, v2 + 0.0), EmptyRow) + checkEvaluation(c(v2, v1), f(v2 + 0.0, v1 + 0.0), EmptyRow) + } + } + checkEvaluation(c(Literal.create(null, DoubleType), 1.0), null, create_row(null)) + checkEvaluation(c(1.0, Literal.create(null, DoubleType)), null, create_row(null)) + } + + test("pow") { + binaryMathFunctionEvaluation(Pow, math.pow, (-5 to 5).map(v => (v * 1.0, v * 1.0))) + binaryMathFunctionEvaluation(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), true) + } + + test("hypot") { + binaryMathFunctionEvaluation(Hypot, math.hypot) + } + + test("atan2") { + binaryMathFunctionEvaluation(Atan2, math.atan2) + } } // TODO: Make the tests work with codegen. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala new file mode 100644 index 0000000000000..7a19e511eb8b5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} +import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} + +import org.apache.spark.sql.types._ + +class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { + + import UnsafeFixedWidthAggregationMap._ + + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) + private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0)) + + private var memoryManager: TaskMemoryManager = null + + override def beforeEach(): Unit = { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + } + + override def afterEach(): Unit = { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory() + memoryManager = null + } + } + + test("supported schemas") { + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) + + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + assert( + !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + + test("empty map") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + assert(!map.iterator().hasNext) + map.free() + } + + test("updating values for a single key") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + val groupKey = new GenericRow(Array[Any](UTF8String("cats"))) + + // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) + map.getAggregationBuffer(groupKey) + val iter = map.iterator() + val entry = iter.next() + assert(!iter.hasNext) + entry.key.getString(0) should be ("cats") + entry.value.getInt(0) should be (0) + + // Modifications to rows retrieved from the map should update the values in the map + entry.value.setInt(0, 42) + map.getAggregationBuffer(groupKey).getInt(0) should be (42) + + map.free() + } + + test("inserting large random keys") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 128, // initial capacity + false // disable perf metrics + ) + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet + groupKeys.foreach { keyString => + map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String(keyString)))) + } + val seenKeys: Set[String] = map.iterator().asScala.map { entry => + entry.key.getString(0) + }.toSet + seenKeys.size should be (groupKeys.size) + seenKeys should be (groupKeys) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala new file mode 100644 index 0000000000000..3a60c7fd32675 --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Arrays + +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +class UnsafeRowConverterSuite extends FunSuite with Matchers { + + test("basic conversion with only primitive types") { + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setLong(1, 1) + row.setInt(2, 2) + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (3 * 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getLong(1) should be (1) + unsafeRow.getInt(2) should be (2) + } + + test("basic conversion with primitive and string types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setString(1, "Hello") + row.setString(2, "World") + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (8 * 3) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getString(1) should be ("Hello") + unsafeRow.getString(2) should be ("World") + } + + test("null handling") { + val fieldTypes: Array[DataType] = Array( + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + val converter = new UnsafeRowConverter(fieldTypes) + + val rowWithAllNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + for (i <- 0 to fieldTypes.length - 1) { + r.setNullAt(i) + } + r + } + + val sizeRequired: Int = 
converter.getSizeRequirement(rowWithAllNullColumns) + val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow( + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val createdFromNull = new UnsafeRow() + createdFromNull.pointTo( + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + for (i <- 0 to fieldTypes.length - 1) { + assert(createdFromNull.isNullAt(i)) + } + createdFromNull.getBoolean(1) should be (false) + createdFromNull.getByte(2) should be (0) + createdFromNull.getShort(3) should be (0) + createdFromNull.getInt(4) should be (0) + createdFromNull.getLong(5) should be (0) + assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) + assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + + // If we have an UnsafeRow with columns that are initially non-null and we null out those + // columns, then the serialized row representation should be identical to what we would get by + // creating an entirely null row via the converter + val rowWithNoNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + r.setNullAt(0) + r.setBoolean(1, false) + r.setByte(2, 20) + r.setShort(3, 30) + r.setInt(4, 400) + r.setLong(5, 500) + r.setFloat(6, 600) + r.setDouble(7, 700) + r + } + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + converter.writeRow( + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + val setToNullAfterCreation = new UnsafeRow() + setToNullAfterCreation.pointTo( + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + + setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) + setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) + setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) + setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) + setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) + setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) + setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) + setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + + for (i <- 0 to fieldTypes.length - 1) { + setToNullAfterCreation.setNullAt(i) + } + assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index 169125264a803..3e7cf7cbb5e63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -23,13 +23,13 @@ class DataTypeParserSuite extends FunSuite { def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser(dataTypeString) === expectedDataType) + assert(DataTypeParser.parse(dataTypeString) === expectedDataType) } } def unsupported(dataTypeString: String): Unit = { test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser(dataTypeString)) + intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index edb229c059e6b..33f9d0b37d006 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -647,7 +647,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group expr_ops */ - def cast(to: String): Column = cast(DataTypeParser(to)) + def cast(to: String): Column = cast(DataTypeParser.parse(to)) /** * Returns an ordering used in sorting. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index ca6ae482eb2ab..0e896e5693b98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -623,17 +623,12 @@ class DataFrame private[sql]( } /** - * (Scala-specific) Compute aggregates by specifying a map from column name to - * aggregate methods. The resulting [[DataFrame]] will also contain the grouping columns. - * - * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * df.groupBy("department").agg( - * "age" -> "max", - * "expense" -> "sum" - * ) - * }}} + * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg("age" -> "max", "salary" -> "avg") + * df.groupBy().agg("age" -> "max", "salary" -> "avg") + * }} * @group dfops */ def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = { @@ -961,9 +956,7 @@ class DataFrame private[sql]( * @group rdd */ override def repartition(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.map(_.copy()).repartition(numPartitions), - schema, needsConversion = false) + Repartition(numPartitions, shuffle = true, logicalPlan) } /** @@ -974,10 +967,7 @@ class DataFrame private[sql]( * @group rdd */ override def coalesce(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame( - queryExecution.toRdd.coalesce(numPartitions), - schema, - needsConversion = false) + Repartition(numPartitions, shuffle = false, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4fc5de7e824fe..2fa602a6082dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -30,6 +30,7 @@ private[spark] object SQLConf { val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" @@ -149,6 +150,14 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + /** + * When set to true, Spark SQL will use managed memory for certain operations. This option only + * takes effect if codegen is enabled. + * + * Defaults to false as this feature is currently experimental. 
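The reworked DataFrame docs and the plan-based repartitioning above boil down to the following usage, shown here only as a hedged reminder; the DataFrame `df` is hypothetical:

```scala
// df.agg(...) without a preceding groupBy aggregates over the whole DataFrame.
df.agg("age" -> "max", "salary" -> "avg")   // same as df.groupBy().agg(...)

// Both calls now produce a logical Repartition node instead of eagerly building
// a new RDD: repartition shuffles, coalesce does not.
df.repartition(200)
df.coalesce(10)
```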
+ */ + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a279b0f07c38a..bd4a55fa132fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1011,6 +1011,8 @@ class SQLContext(@transient val sparkContext: SparkContext) def codegenEnabled: Boolean = self.conf.codegenEnabled + def unsafeEnabled: Boolean = self.conf.unsafeEnabled + def numPartitions: Int = self.conf.numShufflePartitions def strategies: Seq[Strategy] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b1ef6556de1e9..5d9f202681045 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.trees._ @@ -40,6 +41,7 @@ case class AggregateEvaluation( * ensure all values where `groupingExpressions` are equal are present. * @param groupingExpressions expressions that are evaluated to determine grouping. * @param aggregateExpressions expressions that are computed for each group. + * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used. * @param child the input data source. */ @DeveloperApi @@ -47,6 +49,7 @@ case class GeneratedAggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], + unsafeEnabled: Boolean, child: SparkPlan) extends UnaryNode { @@ -225,6 +228,21 @@ case class GeneratedAggregate( case e: Expression if groupMap.contains(e) => groupMap(e) }) + val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) + + val groupKeySchema: StructType = { + val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => + // This is a dummy field name + StructField(idx.toString, expr.dataType, expr.nullable) + } + StructType(fields) + } + + val schemaSupportsUnsafe: Boolean = { + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + } + child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. 
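Since the new flag is only consulted when codegen is on and the schema passes the supportsAggregationBufferSchema/supportsGroupKeySchema checks above, opting into the Unsafe path looks roughly like this; a sketch only, with a made-up table and query:

```scala
// Turn on the experimental managed-memory aggregation path.
sqlContext.setConf("spark.sql.codegen", "true")         // prerequisite for GeneratedAggregate
sqlContext.setConf("spark.sql.unsafe.enabled", "true")  // the flag added in this change

// Queries with fixed-width grouping keys and aggregation buffers can then use it.
sqlContext.sql("SELECT key, SUM(value) FROM records GROUP BY key").collect()
```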
val initialValues = computeFunctions.flatMap(_.initialValues) @@ -265,7 +283,49 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) + } else if (unsafeEnabled && schemaSupportsUnsafe) { + log.info("Using Unsafe-based aggregator") + val aggregationMap = new UnsafeFixedWidthAggregationMap( + newAggregationBuffer(EmptyRow), + aggregationBufferSchema, + groupKeySchema, + TaskContext.get.taskMemoryManager(), + 1024 * 16, // initial capacity + false // disable tracking of performance metrics + ) + + while (iter.hasNext) { + val currentRow: Row = iter.next() + val groupKey: Row = groupProjection(currentRow) + val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) + updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val mapIterator = aggregationMap.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext: Boolean = mapIterator.hasNext + + def next(): Row = { + val entry = mapIterator.next() + val result = resultProjection(joinedRow(entry.key, entry.value)) + if (hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.free() + resultCopy + } + } + } } else { + if (unsafeEnabled) { + log.info("Not using Unsafe-based aggregator because it is not supported for this schema") + } val buffers = new java.util.HashMap[Row, MutableRow]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 030ef118f75d4..af58911cc0e6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -136,10 +136,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { partial = false, namedGroupingAttributes, rewrittenAggregateExpressions, + unsafeEnabled, execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, + unsafeEnabled, planLater(child))) :: Nil // Cases where some aggregate can not be codegened @@ -283,7 +285,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => execution.Distinct(partial = false, execution.Distinct(partial = true, planLater(child))) :: Nil - + case logical.Repartition(numPartitions, shuffle, child) => + execution.Repartition(numPartitions, shuffle, planLater(child)) :: Nil case logical.SortPartitions(sortExprs, child) => // This sort only sorts tuples within a partition. Its requiredDistribution will be // an UnspecifiedDistribution. 
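The iterator returned by the Unsafe branch above defers `aggregationMap.free()` until the last element, and copies that element first so callers never see a row whose pointers reference freed memory. The same pattern in isolation, with hypothetical names and no dependence on the real classes:

```scala
// Sketch of the "copy the final element before freeing the backing memory" idiom.
trait FreeableBuffer { def free(): Unit }

def safeResults[T](underlying: Iterator[T], buffer: FreeableBuffer)(copy: T => T): Iterator[T] =
  new Iterator[T] {
    def hasNext: Boolean = underlying.hasNext
    def next(): T = {
      val result = underlying.next()
      if (hasNext) {
        result                       // backing memory is still live, no copy needed
      } else {
        val detached = copy(result)  // defensive copy of the last element...
        buffer.free()                // ...then release the memory it pointed into
        detached
      }
    }
  }
```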
@@ -317,7 +320,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil - case logical.Repartition(expressions, child) => + case logical.RepartitionByExpression(expressions, child) => execution.Exchange( HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil case e @ EvaluatePython(udf, child, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d286fe81bee5f..1afdb409417ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -245,6 +245,20 @@ case class Distinct(partial: Boolean, child: SparkPlan) extends UnaryNode { } } +/** + * :: DeveloperApi :: + * Return a new RDD that has exactly `numPartitions` partitions. + */ +@DeveloperApi +case class Repartition(numPartitions: Int, shuffle: Boolean, child: SparkPlan) + extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def execute(): RDD[Row] = { + child.execute().map(_.copy()).coalesce(numPartitions, shuffle) + } +} + /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala new file mode 100644 index 0000000000000..9ac732b55b188 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{Row, LeafExpression} +import org.apache.spark.sql.types.{LongType, DataType} + +/** + * Returns monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + * represent the record number within each partition. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * Since this expression is stateful, it cannot be a case object. + */ +private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { + + /** + * Record ID within each partition. 
By being transient, count's value is reset to 0 every time + * we serialize and deserialize it. + */ + @transient private[this] var count: Long = 0L + + override type EvaluatedType = Long + + override def nullable: Boolean = false + + override def dataType: DataType = LongType + + override def eval(input: Row): Long = { + val currentCount = count + count += 1 + (TaskContext.get().partitionId().toLong << 33) + currentCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala new file mode 100644 index 0000000000000..c2c6cbd491598 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.expressions + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.expressions.{LeafExpression, Row} +import org.apache.spark.sql.types.{IntegerType, DataType} + + +/** + * Expression that returns the current partition id of the Spark task. + */ +private[sql] case object SparkPartitionID extends LeafExpression { + + override type EvaluatedType = Int + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override def eval(input: Row): Int = TaskContext.get().partitionId() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala new file mode 100644 index 0000000000000..568b7ac2c5987 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +/** + * Package containing expressions that are specific to Spark runtime. 
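The eval above is plain bit packing: the partition ID shifted into the upper bits and a per-partition counter in the lower 33. Illustrative arithmetic only; the helper is not part of the patch:

```scala
// How MonotonicallyIncreasingID composes its 64-bit values.
def idFor(partitionId: Int, recordNumber: Long): Long =
  (partitionId.toLong << 33) + recordNumber

idFor(0, 0L)  // 0
idFor(0, 1L)  // 1
idFor(1, 0L)  // 8589934592 == 1L << 33
idFor(1, 2L)  // 8589934594
```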
+ */ +package object expressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ff91e1d74bc2c..aa31d04a0cbe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -276,6 +276,13 @@ object functions { // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Computes the absolute value. + * + * @group normal_funcs + */ + def abs(e: Column): Column = Abs(e.expr) + /** * Returns the first column that is not null. * {{{ @@ -287,6 +294,29 @@ object functions { @scala.annotation.varargs def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + /** + * Converts a string exprsesion to lower case. + * + * @group normal_funcs + */ + def lower(e: Column): Column = Lower(e.expr) + + /** + * A column expression that generates monotonically increasing 64-bit integers. + * + * The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + * The current implementation puts the partition ID in the upper 31 bits, and the record number + * within each partition in the lower 33 bits. The assumption is that the data frame has + * less than 1 billion partitions, and each partition has less than 8 billion records. + * + * As an example, consider a [[DataFrame]] with two partitions, each with 3 records. + * This expression would return the following IDs: + * 0, 1, 2, 8589934592 (1L << 33), 8589934593, 8589934594. + * + * @group normal_funcs + */ + def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** * Unary minus, i.e. negate the expression. * {{{ @@ -317,18 +347,13 @@ object functions { def not(e: Column): Column = !e /** - * Converts a string expression to upper case. + * Partition ID of the Spark task. * - * @group normal_funcs - */ - def upper(e: Column): Column = Upper(e.expr) - - /** - * Converts a string exprsesion to lower case. + * Note that this is indeterministic because it depends on data partitioning and task scheduling. * * @group normal_funcs */ - def lower(e: Column): Column = Lower(e.expr) + def sparkPartitionId(): Column = execution.expressions.SparkPartitionID /** * Computes the square root of the specified float value. @@ -338,11 +363,11 @@ object functions { def sqrt(e: Column): Column = Sqrt(e.expr) /** - * Computes the absolutle value. + * Converts a string expression to upper case. 
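Taken together, the additions to functions.scala are used like any other column expression; a hedged sketch in which the DataFrame and its columns are hypothetical:

```scala
import org.apache.spark.sql.functions._

df.select(
  monotonicallyIncreasingId().as("id"),  // unique and increasing, but not consecutive
  sparkPartitionId().as("part"),         // depends on partitioning, so non-deterministic
  lower(col("name")),
  upper(col("name")),
  abs(col("delta")))
```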
* * @group normal_funcs */ - def abs(e: Column): Column = Abs(e.expr) + def upper(e: Column): Column = Upper(e.expr) ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index f326510042122..f3b5455574d1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException} import java.util.Properties -import org.apache.commons.lang.StringEscapeUtils.escapeSql +import org.apache.commons.lang3.StringUtils + import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow} @@ -239,6 +240,9 @@ private[sql] class JDBCRDD( case _ => value } + private def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + /** * Turns a single Filter into a String representing a SQL expression. * Returns null for an unhandled filter. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala new file mode 100644 index 0000000000000..d901542b7eaf3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.mathfuncs._ +import org.apache.spark.sql.functions.lit + +/** + * :: Experimental :: + * Mathematical Functions available for [[DataFrame]]. + * + * @groupname double_funcs Functions that require DoubleType as an input + */ +@Experimental +// scalastyle:off +object mathfunctions { +// scalastyle:on + + private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + + /** + * Computes the cosine inverse of the given value; the returned angle is in the range + * 0.0 through pi. + */ + def acos(e: Column): Column = Acos(e.expr) + + /** + * Computes the cosine inverse of the given column; the returned angle is in the range + * 0.0 through pi. 
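The JDBCRDD change above replaces commons-lang's escapeSql with a local helper that simply doubles single quotes. A self-contained sketch of what it does:

```scala
import org.apache.commons.lang3.StringUtils

def escapeSql(value: String): String =
  if (value == null) null else StringUtils.replace(value, "'", "''")

escapeSql("O'Brien")  // "O''Brien"
escapeSql(null)       // null
```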
+ */ + def acos(columnName: String): Column = acos(Column(columnName)) + + /** + * Computes the sine inverse of the given value; the returned angle is in the range + * -pi/2 through pi/2. + */ + def asin(e: Column): Column = Asin(e.expr) + + /** + * Computes the sine inverse of the given column; the returned angle is in the range + * -pi/2 through pi/2. + */ + def asin(columnName: String): Column = asin(Column(columnName)) + + /** + * Computes the tangent inverse of the given value. + */ + def atan(e: Column): Column = Atan(e.expr) + + /** + * Computes the tangent inverse of the given column. + */ + def atan(columnName: String): Column = atan(Column(columnName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Column, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(leftName: String, r: Column): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(leftName: String, rightName: String): Column = + atan2(Column(leftName), Column(rightName)) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta).= + */ + def atan2(leftName: String, r: Double): Column = atan2(Column(leftName), r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + + /** + * Returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + */ + def atan2(l: Double, rightName: String): Column = atan2(l, Column(rightName)) + + /** + * Computes the cube-root of the given value. + */ + def cbrt(e: Column): Column = Cbrt(e.expr) + + /** + * Computes the cube-root of the given column. + */ + def cbrt(columnName: String): Column = cbrt(Column(columnName)) + + /** + * Computes the ceiling of the given value. + */ + def ceil(e: Column): Column = Ceil(e.expr) + + /** + * Computes the ceiling of the given column. + */ + def ceil(columnName: String): Column = ceil(Column(columnName)) + + /** + * Computes the cosine of the given value. + */ + def cos(e: Column): Column = Cos(e.expr) + + /** + * Computes the cosine of the given column. + */ + def cos(columnName: String): Column = cos(Column(columnName)) + + /** + * Computes the hyperbolic cosine of the given value. + */ + def cosh(e: Column): Column = Cosh(e.expr) + + /** + * Computes the hyperbolic cosine of the given column. + */ + def cosh(columnName: String): Column = cosh(Column(columnName)) + + /** + * Computes the exponential of the given value. + */ + def exp(e: Column): Column = Exp(e.expr) + + /** + * Computes the exponential of the given column. 
+ */ + def exp(columnName: String): Column = exp(Column(columnName)) + + /** + * Computes the exponential of the given value minus one. + */ + def expm1(e: Column): Column = Expm1(e.expr) + + /** + * Computes the exponential of the given column. + */ + def expm1(columnName: String): Column = expm1(Column(columnName)) + + /** + * Computes the floor of the given value. + */ + def floor(e: Column): Column = Floor(e.expr) + + /** + * Computes the floor of the given column. + */ + def floor(columnName: String): Column = floor(Column(columnName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Column, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(leftName: String, r: Column): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(leftName: String, rightName: String): Column = + hypot(Column(leftName), Column(rightName)) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(leftName: String, r: Double): Column = hypot(Column(leftName), r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + + /** + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + */ + def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName)) + + /** + * Computes the natural logarithm of the given value. + */ + def log(e: Column): Column = Log(e.expr) + + /** + * Computes the natural logarithm of the given column. + */ + def log(columnName: String): Column = log(Column(columnName)) + + /** + * Computes the logarithm of the given value in Base 10. + */ + def log10(e: Column): Column = Log10(e.expr) + + /** + * Computes the logarithm of the given value in Base 10. + */ + def log10(columnName: String): Column = log10(Column(columnName)) + + /** + * Computes the natural logarithm of the given value plus one. + */ + def log1p(e: Column): Column = Log1p(e.expr) + + /** + * Computes the natural logarithm of the given column plus one. + */ + def log1p(columnName: String): Column = log1p(Column(columnName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Column, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(leftName: String, r: Column): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(leftName: String, rightName: String): Column = pow(Column(leftName), Column(rightName)) + + /** + * Returns the value of the first argument raised to the power of the second argument. 
+ */ + def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(leftName: String, r: Double): Column = pow(Column(leftName), r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + + /** + * Returns the value of the first argument raised to the power of the second argument. + */ + def pow(l: Double, rightName: String): Column = pow(l, Column(rightName)) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + */ + def rint(e: Column): Column = Rint(e.expr) + + /** + * Returns the double value that is closest in value to the argument and + * is equal to a mathematical integer. + */ + def rint(columnName: String): Column = rint(Column(columnName)) + + /** + * Computes the signum of the given value. + */ + def signum(e: Column): Column = Signum(e.expr) + + /** + * Computes the signum of the given column. + */ + def signum(columnName: String): Column = signum(Column(columnName)) + + /** + * Computes the sine of the given value. + */ + def sin(e: Column): Column = Sin(e.expr) + + /** + * Computes the sine of the given column. + */ + def sin(columnName: String): Column = sin(Column(columnName)) + + /** + * Computes the hyperbolic sine of the given value. + */ + def sinh(e: Column): Column = Sinh(e.expr) + + /** + * Computes the hyperbolic sine of the given column. + */ + def sinh(columnName: String): Column = sinh(Column(columnName)) + + /** + * Computes the tangent of the given value. + */ + def tan(e: Column): Column = Tan(e.expr) + + /** + * Computes the tangent of the given column. + */ + def tan(columnName: String): Column = tan(Column(columnName)) + + /** + * Computes the hyperbolic tangent of the given value. + */ + def tanh(e: Column): Column = Tanh(e.expr) + + /** + * Computes the hyperbolic tangent of the given column. + */ + def tanh(columnName: String): Column = tanh(Column(columnName)) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + */ + def toDeg(e: Column): Column = ToDegrees(e.expr) + + /** + * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + */ + def toDeg(columnName: String): Column = toDeg(Column(columnName)) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + */ + def toRad(e: Column): Column = ToRadians(e.expr) + + /** + * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. + */ + def toRad(columnName: String): Column = toRad(Column(columnName)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala new file mode 100644 index 0000000000000..f5ce2718bec4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/DirectParquetOutputCommitter.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
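The mathfunctions object that closes just above is exercised from Scala much like functions; a hedged usage sketch with a hypothetical DataFrame and columns:

```scala
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.mathfunctions._

df.select(sin("a"), cos(col("a")), pow("a", 2.0), atan2("a", "b"))
df.select(hypot(col("a"), col("b")), log1p("a"), toDeg("a"))
```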
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import parquet.Log +import parquet.hadoop.util.ContextUtil +import parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} + +private[parquet] class DirectParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { + val LOG = Log.getLog(classOf[ParquetOutputCommitter]) + + override def getWorkPath(): Path = outputPath + override def abortTask(taskContext: TaskAttemptContext): Unit = {} + override def commitTask(taskContext: TaskAttemptContext): Unit = {} + override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true + override def setupJob(jobContext: JobContext): Unit = {} + override def setupTask(taskContext: TaskAttemptContext): Unit = {} + + override def commitJob(jobContext: JobContext) { + val configuration = ContextUtil.getConfiguration(jobContext) + val fileSystem = outputPath.getFileSystem(configuration) + + if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { + try { + val outputStatus = fileSystem.getFileStatus(outputPath) + val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) + try { + ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) + } catch { + case e: Exception => { + LOG.warn("could not write summary file for " + outputPath, e) + val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fileSystem.exists(metadataPath)) { + fileSystem.delete(metadataPath, true) + } + } + } + } catch { + case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) + } + } + + if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { + try { + val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fileSystem.create(successPath).close() + } catch { + case e: Exception => LOG.warn("could not write success file for " + outputPath, e) + } + } + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index a938b77578686..aded126ea0615 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -381,6 +381,7 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory override def checkOutputSpecs(job: JobContext): Unit = {} + var committer: OutputCommitter = null // override to choose output filename so not overwrite 
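As the new ParquetIOSuite test further down shows, the committer is selected through the Hadoop Configuration used for the write; a hedged sketch of opting in, with a made-up DataFrame and path:

```scala
// Skip the _temporary directory by writing through DirectParquetOutputCommitter.
sqlContext.sparkContext.hadoopConfiguration.set(
  "spark.sql.parquet.output.committer.class",
  "org.apache.spark.sql.parquet.DirectParquetOutputCommitter")

df.saveAsParquetFile("/data/events.parquet")
```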
existing ones override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { @@ -403,6 +404,26 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] } + + // override to create output committer from configuration + override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { + if (committer == null) { + val output = getOutputPath(context) + val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", + classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) + val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] + } + committer + } + + // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 + private def getOutputPath(context: TaskAttemptContext): Path = { + context.getConfiguration().get("mapred.output.dir") match { + case null => null + case name => new Path(name) + } + } } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index e02c84872c628..e5c9504d21042 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -41,6 +41,7 @@ import java.util.Map; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.mathfunctions.*; public class JavaDataFrameSuite { private transient JavaSparkContext jsc; @@ -98,6 +99,14 @@ public void testVarargMethods() { df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); + + // Varargs with mathfunctions + DataFrame df2 = context.table("testData2"); + df2.select(exp("a"), exp("b")); + df2.select(exp(log("a"))); + df2.select(pow("a", "a"), pow("b", 2.0)); + df2.select(pow(col("a"), col("b")), exp("b")); + df2.select(sin("a"), acos("b")); } @Ignore diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bc8fae100db6a..2ba5fc21ff57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ - class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ @@ -310,6 +309,25 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("monotonicallyIncreasingId") { + // Make sure we have 2 partitions, each with 2 records. 
+ val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + Iterator(Tuple1(1), Tuple1(2)) + }.toDF("a") + checkAnswer( + df.select(monotonicallyIncreasingId()), + Row(0L) :: Row(1L) :: Row((1L << 33) + 0L) :: Row((1L << 33) + 1L) :: Nil + ) + } + + test("sparkPartitionId") { + val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + checkAnswer( + df.select(sparkPartitionId()), + Row(0) + ) + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala new file mode 100644 index 0000000000000..9e19bb7482e9b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.{Double => JavaDouble} + +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.mathfunctions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.test.TestSQLContext.implicits._ + +private[this] object MathExpressionsTestData { + + case class DoubleData(a: JavaDouble, b: JavaDouble) + val doubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() + + val nnDoubleData = TestSQLContext.sparkContext.parallelize( + (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() + + case class NullDoubles(a: JavaDouble) + val nullDoubles = + TestSQLContext.sparkContext.parallelize( + NullDoubles(1.0) :: + NullDoubles(2.0) :: + NullDoubles(3.0) :: + NullDoubles(null) :: Nil + ).toDF() +} + +class MathExpressionsSuite extends QueryTest { + + import MathExpressionsTestData._ + + def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + c: Column => Column, + f: T => T): Unit = { + checkAnswer( + doubleData.select(c('a)), + (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c('b)), + (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) + ) + + checkAnswer( + doubleData.select(c(lit(null))), + (1 to 10).map(_ => Row(null)) + ) + } + + def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a)), + (1 to 10).map(n => Row(f(n * 0.1))) + ) + + if (f(-1) === math.log1p(-1)) { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(Double.NegativeInfinity) + ) + } else { + checkAnswer( + nnDoubleData.select(c('b)), + (1 to 10).map(n => Row(null)) + ) + } + + checkAnswer( + nnDoubleData.select(c(lit(null))), + (1 
to 10).map(_ => Row(null)) + ) + } + + def testTwoToOneMathFunction( + c: (Column, Column) => Column, + d: (Column, Double) => Column, + f: (Double, Double) => Double): Unit = { + checkAnswer( + nnDoubleData.select(c('a, 'a)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + + checkAnswer( + nnDoubleData.select(c('a, 'b)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) + ) + + checkAnswer( + nnDoubleData.select(d('a, 2.0)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) + ) + + checkAnswer( + nnDoubleData.select(d('a, -0.5)), + nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) + ) + + val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) + + checkAnswer( + nullDoubles.select(c('a, 'a)).orderBy('a.asc), + Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) + ) + } + + test("sin") { + testOneToOneMathFunction(sin, math.sin) + } + + test("asin") { + testOneToOneMathFunction(asin, math.asin) + } + + test("sinh") { + testOneToOneMathFunction(sinh, math.sinh) + } + + test("cos") { + testOneToOneMathFunction(cos, math.cos) + } + + test("acos") { + testOneToOneMathFunction(acos, math.acos) + } + + test("cosh") { + testOneToOneMathFunction(cosh, math.cosh) + } + + test("tan") { + testOneToOneMathFunction(tan, math.tan) + } + + test("atan") { + testOneToOneMathFunction(atan, math.atan) + } + + test("tanh") { + testOneToOneMathFunction(tanh, math.tanh) + } + + test("toDeg") { + testOneToOneMathFunction(toDeg, math.toDegrees) + } + + test("toRad") { + testOneToOneMathFunction(toRad, math.toRadians) + } + + test("cbrt") { + testOneToOneMathFunction(cbrt, math.cbrt) + } + + test("ceil") { + testOneToOneMathFunction(ceil, math.ceil) + } + + test("floor") { + testOneToOneMathFunction(floor, math.floor) + } + + test("rint") { + testOneToOneMathFunction(rint, math.rint) + } + + test("exp") { + testOneToOneMathFunction(exp, math.exp) + } + + test("expm1") { + testOneToOneMathFunction(expm1, math.expm1) + } + + test("signum") { + testOneToOneMathFunction[Double](signum, math.signum) + } + + test("pow") { + testTwoToOneMathFunction(pow, pow, math.pow) + } + + test("hypot") { + testTwoToOneMathFunction(hypot, hypot, math.hypot) + } + + test("atan2") { + testTwoToOneMathFunction(atan2, atan2, math.atan2) + } + + test("log") { + testOneToOneNonNegativeMathFunction(log, math.log) + } + + test("log10") { + testOneToOneNonNegativeMathFunction(log10, math.log10) + } + + test("log1p") { + testOneToOneNonNegativeMathFunction(log1p, math.log1p) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 97c0f439acf13..b504842053690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -381,6 +381,28 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } + + test("SPARK-6352 DirectParquetOutputCommitter") { + // Write to a parquet file and let it fail. + // _temporary should be missing if direct output committer works. 
+ try { + configuration.set("spark.sql.parquet.output.committer.class", + "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") + sqlContext.udf.register("div0", (x: Int) => x / 0) + withTempPath { dir => + intercept[org.apache.spark.SparkException] { + sqlContext.sql("select div0(1)").saveAsParquetFile(dir.getCanonicalPath) + } + val path = new Path(dir.getCanonicalPath, "_temporary") + val fs = path.getFileSystem(configuration) + assert(!fs.exists(path)) + } + } + finally { + configuration.set("spark.sql.parquet.output.committer.class", + "parquet.hadoop.ParquetOutputCommitter") + } + } } class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala index 59f3a75768082..48ac9062af96a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/AbstractSparkSQLDriver.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.thriftserver import scala.collection.JavaConversions._ -import org.apache.commons.lang.exception.ExceptionUtils +import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse @@ -61,7 +61,7 @@ private[hive] abstract class AbstractSparkSQLDriver( } catch { case cause: Throwable => logError(s"Failed in [$command]", cause) - new CommandProcessorResponse(1, ExceptionUtils.getFullStackTrace(cause), null) + new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7e307bb4ad1e8..b7b6925aa87f7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -24,18 +24,16 @@ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} -import org.apache.commons.lang.StringUtils +import org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.cli.{CliDriver, CliSessionState, OptionsProcessor} -import org.apache.hadoop.hive.common.LogUtils.LogInitializationException -import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils, LogUtils} +import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor} import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket import org.apache.spark.Logging diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 
21dce8d8a565a..e322340094e6f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -183,7 +183,6 @@ org.apache.maven.plugins maven-dependency-plugin - 2.4 copy-dependencies diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f1c0bd92aa23d..4d222cf88e5e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -871,7 +871,7 @@ private[hive] case class MetastoreRelation private[hive] object HiveMetastoreTypes { - def toDataType(metastoreType: String): DataType = DataTypeParser(metastoreType) + def toDataType(metastoreType: String): DataType = DataTypeParser.parse(metastoreType) def toMetastoreType(dt: DataType): String = dt match { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0ea6d57b816c6..2dc6463abafa7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -783,13 +783,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case (None, Some(perPartitionOrdering), None, None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving) case (None, None, Some(partitionExprs), None) => - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving) case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, - Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(partitionExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, Some(clusterExprs)) => Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false, - Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving)) + RepartitionByExpression(clusterExprs.getChildren.map(nodeToExpr), withHaving)) case (None, None, None, None) => withHaving case _ => sys.error("Unsupported set of ordering / distribution clauses.") } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java index efd34df293c88..f33210ebdae1b 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFListString.java @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.execution; -import org.apache.hadoop.hive.ql.exec.UDF; - import java.util.List; -import org.apache.commons.lang.StringUtils; + +import org.apache.commons.lang3.StringUtils; +import org.apache.hadoop.hive.ql.exec.UDF; public class UDFListString extends UDF { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e09c702c8969e..0538aa203c5a0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -23,7 +23,6 @@ import 
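The HiveQl rename above means the familiar Hive clauses now plan to RepartitionByExpression rather than the numeric Repartition node; roughly, with a hypothetical table and queries:

```scala
hiveContext.sql("SELECT * FROM logs DISTRIBUTE BY key")             // RepartitionByExpression(key)
hiveContext.sql("SELECT * FROM logs DISTRIBUTE BY key SORT BY ts")  // repartition, then sort within each partition
hiveContext.sql("SELECT * FROM logs CLUSTER BY key")                // distribute by key and sort by key
```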
scala.collection.mutable.ArrayBuffer import org.scalatest.BeforeAndAfterEach -import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.metastore.TableType import org.apache.hadoop.hive.ql.metadata.Table @@ -174,7 +173,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -190,7 +189,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a1", "b1", "c1")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("drop, change, recreate") { @@ -212,7 +211,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { sql("SELECT * FROM jsonTable"), Row("a", "b")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) sparkContext.parallelize(("a", "b", "c") :: Nil).toDF() .toJSON.saveAsTextFile(tempDir.getCanonicalPath) @@ -231,7 +230,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach { checkAnswer( sql("SELECT * FROM jsonTable"), Row("a", "b", "c")) - FileUtils.deleteDirectory(tempDir) + Utils.deleteRecursively(tempDir) } test("invalidate cache and reload") { diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index d331c210e8939..dbc5e029e2047 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.hive import java.rmi.server.UID import java.util.{Properties, ArrayList => JArrayList} +import java.io.{OutputStream, InputStream} import scala.collection.JavaConversions._ import scala.language.implicitConversions +import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst @@ -46,6 +50,7 @@ import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging import org.apache.spark.sql.types.{Decimal, DecimalType, UTF8String} +import org.apache.spark.util.Utils._ /** * This class provides the UDF creation and also the UDF instance serialization and @@ -61,39 +66,34 @@ private[hive] case class HiveFunctionWrapper(var functionClassName: String) // for Serialization def this() = this(null) - import org.apache.spark.util.Utils._ - @transient - private val methodDeSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "deserializeObjectByKryo", - classOf[Kryo], - classOf[java.io.InputStream], - classOf[Class[_]]) - method.setAccessible(true) - - method + def deserializeObjectByKryo[T: ClassTag]( + kryo: Kryo, + in: InputStream, + clazz: Class[_]): T = { + val inp = new Input(in) + val t: T = kryo.readObject(inp,clazz).asInstanceOf[T] + inp.close() + t } @transient - private val methodSerialize = { - val method = classOf[Utilities].getDeclaredMethod( - "serializeObjectByKryo", - classOf[Kryo], - classOf[Object], - classOf[java.io.OutputStream]) - method.setAccessible(true) - - method + def serializeObjectByKryo( + kryo: Kryo, + plan: Object, + out: 
OutputStream ) { + val output: Output = new Output(out) + kryo.writeObject(output, plan) + output.close() } def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { - methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz) + deserializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), is, clazz) .asInstanceOf[UDFType] } def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { - methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out) + serializeObjectByKryo(Utilities.runtimeSerializationKryo.get(), function, out) } private var instance: AnyRef = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index f4963a78e1d18..4bebcc5aa7ca0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -126,6 +126,20 @@ private[streaming] class BlockGenerator( listener.onAddData(data, metadata) } + /** + * Push multiple data items into the buffer. After buffering the data, the + * `BlockGeneratorListener.onAddData` callback will be called. All received data items + * will be periodically pushed into BlockManager. Note that all the data items is guaranteed + * to be present in a single block. + */ + def addMultipleDataWithCallback(dataIterator: Iterator[Any], metadata: Any): Unit = synchronized { + dataIterator.foreach { data => + waitToPush() + currentBuffer += data + } + listener.onAddData(dataIterator, metadata) + } + /** Change the buffer to which single records are added to. */ private def updateCurrentBuffer(time: Long): Unit = synchronized { try { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 89af40330b9d9..f2379366f36ef 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -146,7 +146,7 @@ private[streaming] class ReceiverSupervisorImpl( logDebug(s"Pushed block $blockId in ${(System.currentTimeMillis - time)} ms") val blockInfo = ReceivedBlockInfo(streamId, numRecords, blockStoreResult) - trackerEndpoint.askWithReply[Boolean](AddBlock(blockInfo)) + trackerEndpoint.askWithRetry[Boolean](AddBlock(blockInfo)) logDebug(s"Reported block $blockId") } @@ -169,13 +169,13 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onReceiverStart() { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) - trackerEndpoint.askWithReply[Boolean](msg) + trackerEndpoint.askWithRetry[Boolean](msg) } override protected def onReceiverStop(message: String, error: Option[Throwable]) { logInfo("Deregistering receiver " + streamId) val errorString = error.map(Throwables.getStackTraceAsString).getOrElse("") - trackerEndpoint.askWithReply[Boolean](DeregisterReceiver(streamId, message, errorString)) + trackerEndpoint.askWithRetry[Boolean](DeregisterReceiver(streamId, message, errorString)) logInfo("Stopped receiver " + streamId) } diff --git a/unsafe/pom.xml b/unsafe/pom.xml new file mode 100644 index 0000000000000..8901d77591932 --- /dev/null +++ b/unsafe/pom.xml @@ -0,0 +1,69 @@ + + + + + 4.0.0 + + 
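The Shim13 change replaces reflective calls into Hive's Utilities with direct Kryo I/O. A standalone round trip using the same Kryo classes, independent of Hive; the payload and target class are arbitrary examples:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

val kryo = new Kryo()

// Mirrors serializeObjectByKryo: wrap the stream in an Output and write the object.
val bytes = new ByteArrayOutputStream()
val output = new Output(bytes)
kryo.writeObject(output, "a UDF payload")
output.close()

// Mirrors deserializeObjectByKryo: wrap the stream in an Input and read it back.
val input = new Input(new ByteArrayInputStream(bytes.toByteArray))
val restored = kryo.readObject(input, classOf[String])
input.close()
```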
org.apache.spark + spark-parent_2.10 + 1.4.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-unsafe_2.10 + jar + Spark Project Unsafe + http://spark.apache.org/ + + unsafe + + + + + + + com.google.code.findbugs + jsr305 + + + + + org.slf4j + slf4j-api + provided + + + + + junit + junit + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java new file mode 100644 index 0000000000000..91b2f9aa43921 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +public final class PlatformDependent { + + public static final Unsafe UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + UNSAFE = unsafe; + + if (UNSAFE != null) { + BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } + + static public void copyMemory( + Object src, + long srcOffset, + Object dst, + long dstOffset, + long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. 
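+ * <p>
+ * A rough illustration of why this is useful ({@code in} is a placeholder for any closeable
+ * stream): it lets callers rethrow a checked exception without declaring it.
+ * <pre>
+ *   try {
+ *     in.close();
+ *   } catch (java.io.IOException e) {
+ *     PlatformDependent.throwException(e);  // no "throws IOException" required on the caller
+ *   }
+ * </pre>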
+ */ + public static void throwException(Throwable t) { + UNSAFE.throwException(t); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java new file mode 100644 index 0000000000000..53eadf96a6b52 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; + +public class ByteArrayMethods { + + private ByteArrayMethods() { + // Private constructor, since this class only contains static methods. + } + + public static int roundNumberOfBytesToNearestWord(int numBytes) { + int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + if (remainder == 0) { + return numBytes; + } else { + return numBytes + (8 - remainder); + } + } + + /** + * Optimized byte array equality check for 8-byte-word-aligned byte arrays. + * @return true if the arrays are equal, false otherwise + */ + public static boolean wordAlignedArrayEquals( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset, + long arrayLengthInBytes) { + for (int i = 0; i < arrayLengthInBytes; i += 8) { + final long left = + PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i); + final long right = + PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i); + if (left != right) return false; + } + return true; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java new file mode 100644 index 0000000000000..18d1f0d2d7eb2 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * An array of long values. Compared with native JVM arrays, this: + *
+ * <ul>
+ *   <li>supports using both in-heap and off-heap memory</li>
+ *   <li>has no bound checking, and thus can crash the JVM process when assert is turned off</li>
+ * </ul>
+ */ +public final class LongArray { + + // This is a long so that we perform long multiplications when computing offsets. + private static final long WIDTH = 8; + + private final MemoryBlock memory; + private final Object baseObj; + private final long baseOffset; + + private final long length; + + public LongArray(MemoryBlock memory) { + assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; + assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; + this.memory = memory; + this.baseObj = memory.getBaseObject(); + this.baseOffset = memory.getBaseOffset(); + this.length = memory.size() / WIDTH; + } + + public MemoryBlock memoryBlock() { + return memory; + } + + /** + * Returns the number of elements this array can hold. + */ + public long size() { + return length; + } + + /** + * Sets the value at position {@code index}. + */ + public void set(int index, long value) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + } + + /** + * Returns the value at position {@code index}. + */ + public long get(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java new file mode 100644 index 0000000000000..f72e07fce92fd --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * A fixed size uncompressed bit set backed by a {@link LongArray}. + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSet { + + /** A long array for the bits. */ + private final LongArray words; + + /** Length of the long array. */ + private final int numWords; + + private final Object baseObject; + private final long baseOffset; + + /** + * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be + * multiple of 8 bytes (i.e. 64 bits). 
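+ * <p>
+ * A minimal construction sketch (illustrative only; the 8-word buffer is arbitrary), backing a
+ * 512-bit set with an on-heap long[]:
+ * <pre>
+ *   BitSet bits = new BitSet(MemoryBlock.fromLongArray(new long[8]).zero());
+ * </pre>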
+ */ + public BitSet(MemoryBlock memory) { + words = new LongArray(memory); + assert (words.size() <= Integer.MAX_VALUE); + numWords = (int) words.size(); + baseObject = words.memoryBlock().getBaseObject(); + baseOffset = words.memoryBlock().getBaseOffset(); + } + + public MemoryBlock memoryBlock() { + return words.memoryBlock(); + } + + /** + * Returns the number of bits in this {@code BitSet}. + */ + public long capacity() { + return numWords * 64; + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public void set(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.set(baseObject, baseOffset, index); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public void unset(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.unset(baseObject, baseOffset, index); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public boolean isSet(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + return BitSetMethods.isSet(baseObject, baseOffset, index); + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + *

+ * To iterate over the true bits in a BitSet, use the following loop:
+ * <pre>
+ * <code>
+ *  for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
+ *    // operate on index i here
+ *  }
+ * </code>
+ * </pre>
+ * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + public int nextSetBit(int fromIndex) { + return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java new file mode 100644 index 0000000000000..f30626d8f4317 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Methods for working with fixed-size uncompressed bitsets. + * + * We assume that the bitset data is word-aligned (that is, a multiple of 8 bytes in length). + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSetMethods { + + private static final long WORD_SIZE = 8; + + private BitSetMethods() { + // Make the default constructor private, since this only holds static methods. + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public static void set(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public static void unset(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public static boolean isSet(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + return (word & mask) != 0; + } + + /** + * Returns {@code true} if any bit is set. 
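+ * <p>
+ * Illustrative sketch only (the buffer and the 16-byte width are arbitrary):
+ * <pre>
+ *   long[] buf = new long[2];                        // 128 bits, all clear
+ *   long off = PlatformDependent.LONG_ARRAY_OFFSET;
+ *   BitSetMethods.anySet(buf, off, 16);              // false: nothing set yet
+ *   BitSetMethods.set(buf, off, 70);                 // sets a bit in the second word
+ *   BitSetMethods.anySet(buf, off, 16);              // true
+ * </pre>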
+ */ + public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInBytes) { + for (int i = 0; i < bitSetWidthInBytes; i++) { + if (PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i) != 0) { + return true; + } + } + return false; + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + *

+ * To iterate over the true bits in a BitSet, use the following loop:
+ * <pre>
+ * <code>
+ *  for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
+ *    // operate on index i here
+ *  }
+ * </code>
+ * </pre>
+ * + * @param fromIndex the index to start checking from (inclusive) + * @param bitsetSizeInWords the size of the bitset, measured in 8-byte words + * @return the index of the next set bit, or -1 if there is no such bit + */ + public static int nextSetBit( + Object baseObject, + long baseOffset, + int fromIndex, + int bitsetSizeInWords) { + int wi = fromIndex >> 6; + if (wi >= bitsetSizeInWords) { + return -1; + } + + // Try to find the next set bit in the current word + final int subIndex = fromIndex & 0x3f; + long word = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + if (word != 0) { + return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); + } + + // Find the next set bit in the rest of the words + wi += 1; + while (wi < bitsetSizeInWords) { + word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + if (word != 0) { + return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); + } + wi += 1; + } + + return -1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java new file mode 100644 index 0000000000000..85cd02469adb7 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +public final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + public Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. 
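+    // Note: the assert below enforces the caller contract that the input is a whole number of
+    // 8-byte words; the loop itself still consumes the data four bytes at a time ("halfWord"),
+    // which is Murmur3's 32-bit block size.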
+ assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = seed; + for (int offset = 0; offset < lengthInBytes; offset += 4) { + int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + public int hashLong(long input) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java new file mode 100644 index 0000000000000..85b64c0833803 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import java.lang.Override; +import java.lang.UnsupportedOperationException; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.bitset.BitSet; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.*; + +/** + * An append-only hash map where keys and values are contiguous regions of bytes. + *

+ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers, + * which is guaranteed to exhaust the space. + *
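+ * <p>
+ * Here "triangular numbers" means the i-th probe lands i*(i+1)/2 slots past the initial one;
+ * for a power-of-2 capacity this sequence visits every slot. The lookup loop produces it
+ * incrementally (sketch of the probing step only):
+ * <pre>
+ *   int pos = hashcode & mask;
+ *   int step = 1;
+ *   // on each unsuccessful probe:
+ *   pos = (pos + step) & mask;
+ *   step++;
+ *   // offsets visited from the initial slot: 0, 1, 3, 6, 10, ... (triangular numbers)
+ * </pre>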

+ * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is + * higher than this, you should probably be using sorting instead of hashing for better cache + * locality. + *

+ * This class is not thread safe. + */ +public final class BytesToBytesMap { + + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); + + private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + + private final TaskMemoryManager memoryManager; + + /** + * A linked list for tracking all allocated data pages so that we can free all of our memory. + */ + private final List dataPages = new LinkedList(); + + /** + * The data page that will be used to store keys and values for new hashtable entries. When this + * page becomes full, a new page will be allocated and this pointer will change to point to that + * new page. + */ + private MemoryBlock currentDataPage = null; + + /** + * Offset into `currentDataPage` that points to the location where new data can be inserted into + * the page. + */ + private long pageCursor = 0; + + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + + // This choice of page table size and page size means that we can address up to 500 gigabytes + // of memory. + + /** + * A single array to store the key and value. + * + * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, + * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. + */ + private LongArray longArray; + // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode + // and exploit word-alignment to use fewer bits to hold the address. This might let us store + // only one long per map entry, increasing the chance that this array will fit in cache at the + // expense of maybe performing more lookups if we have hash collisions. Say that we stored only + // 27 bits of the hashcode and 37 bits of the address. 37 bits is enough to address 1 terabyte + // of RAM given word-alignment. If we use 13 bits of this for our page table, that gives us a + // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store + // full base addresses in the page table for off-heap mode so that we can reconstruct the full + // absolute memory addresses. + + /** + * A {@link BitSet} used to track location of the map where the key is set. + * Size of the bitset should be half of the size of the long array. + */ + private BitSet bitset; + + private final double loadFactor; + + /** + * Number of keys defined in the map. + */ + private int size; + + /** + * The map will be expanded once the number of keys exceeds this threshold. + */ + private int growthThreshold; + + /** + * Mask for truncating hashcodes so that they do not exceed the long array's size. + * This is a strength reduction optimization; we're essentially performing a modulus operation, + * but doing so with a bitmask because this is a power-of-2-sized hash map. + */ + private int mask; + + /** + * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}. 
+ */ + private final Location loc; + + private final boolean enablePerfMetrics; + + private long timeSpentResizingNs = 0; + + private long numProbes = 0; + + private long numKeyLookups = 0; + + private long numHashCollisions = 0; + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + double loadFactor, + boolean enablePerfMetrics) { + this.memoryManager = memoryManager; + this.loadFactor = loadFactor; + this.loc = new Location(); + this.enablePerfMetrics = enablePerfMetrics; + allocate(initialCapacity); + } + + public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { + this(memoryManager, initialCapacity, 0.70, false); + } + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + } + + /** + * Returns the number of keys defined in the map. + */ + public int size() { return size; } + + /** + * Returns an iterator for iterating over the entries of this map. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public Iterator iterator() { + return new Iterator() { + + private int nextPos = bitset.nextSetBit(0); + + @Override + public boolean hasNext() { + return nextPos != -1; + } + + @Override + public Location next() { + final int pos = nextPos; + nextPos = bitset.nextSetBit(nextPos + 1); + return loc.with(pos, 0, true); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Looks up a key, and return a {@link Location} handle that can be used to test existence + * and read/write values. + * + * This function always return the same {@link Location} instance to avoid object allocation. + */ + public Location lookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes) { + if (enablePerfMetrics) { + numKeyLookups++; + } + final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + int pos = hashcode & mask; + int step = 1; + while (true) { + if (enablePerfMetrics) { + numProbes++; + } + if (!bitset.isSet(pos)) { + // This is a new key. + return loc.with(pos, hashcode, false); + } else { + long stored = longArray.get(pos * 2 + 1); + if ((int) (stored) == hashcode) { + // Full hash code matches. Let's compare the keys for equality. + loc.with(pos, hashcode, true); + if (loc.getKeyLength() == keyRowLengthBytes) { + final MemoryLocation keyAddress = loc.getKeyAddress(); + final Object storedKeyBaseObject = keyAddress.getBaseObject(); + final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals( + keyBaseObject, + keyBaseOffset, + storedKeyBaseObject, + storedKeyBaseOffset, + keyRowLengthBytes + ); + if (areEqual) { + return loc; + } else { + if (enablePerfMetrics) { + numHashCollisions++; + } + } + } + } + } + pos = (pos + step) & mask; + step++; + } + } + + /** + * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function. 
+ */ + public final class Location { + /** An index into the hash map's Long array */ + private int pos; + /** True if this location points to a position where a key is defined, false otherwise */ + private boolean isDefined; + /** + * The hashcode of the most recent key passed to + * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to + * avoid re-hashing the key when storing a value for that key. + */ + private int keyHashcode; + private final MemoryLocation keyMemoryLocation = new MemoryLocation(); + private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private int keyLength; + private int valueLength; + + private void updateAddressesAndSizes(long fullKeyAddress) { + final Object page = memoryManager.getPage(fullKeyAddress); + final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress); + long position = keyOffsetInPage; + keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + keyMemoryLocation.setObjAndOffset(page, position); + position += keyLength; + valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + valueMemoryLocation.setObjAndOffset(page, position); + } + + Location with(int pos, int keyHashcode, boolean isDefined) { + this.pos = pos; + this.isDefined = isDefined; + this.keyHashcode = keyHashcode; + if (isDefined) { + final long fullKeyAddress = longArray.get(pos * 2); + updateAddressesAndSizes(fullKeyAddress); + } + return this; + } + + /** + * Returns true if the key is defined at this position, and false otherwise. + */ + public boolean isDefined() { + return isDefined; + } + + /** + * Returns the address of the key defined at this position. + * This points to the first byte of the key data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getKeyAddress() { + assert (isDefined); + return keyMemoryLocation; + } + + /** + * Returns the length of the key defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getKeyLength() { + assert (isDefined); + return keyLength; + } + + /** + * Returns the address of the value defined at this position. + * This points to the first byte of the value data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getValueAddress() { + assert (isDefined); + return valueMemoryLocation; + } + + /** + * Returns the length of the value defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getValueLength() { + assert (isDefined); + return valueLength; + } + + /** + * Store a new key and value. This method may only be called once for a given key; if you want + * to update the value associated with a key, then you can directly manipulate the bytes stored + * at the value address. + *

+ * It is only valid to call this method immediately after calling `lookup()` using the same key.
+ * <p>
+ * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
+ * will return information on the data stored by this `putNewKey` call.
+ * <p>
+ * As an example usage, here's the proper way to store a new key:
+ * <p>
+ * <pre>
+ *   Location loc = map.lookup(keyBaseOffset, keyBaseObject, keyLengthInBytes);
+ *   if (!loc.isDefined()) {
+ *     loc.putNewKey(keyBaseOffset, keyBaseObject, keyLengthInBytes, ...)
+ *   }
+ * </pre>
+ * <p>

+ * Unspecified behavior if the key is not defined. + */ + public void putNewKey( + Object keyBaseObject, + long keyBaseOffset, + int keyLengthBytes, + Object valueBaseObject, + long valueBaseOffset, + int valueLengthBytes) { + assert (!isDefined) : "Can only set value once for a key"; + isDefined = true; + assert (keyLengthBytes % 8 == 0); + assert (valueLengthBytes % 8 == 0); + // Here, we'll copy the data into our data pages. Because we only store a relative offset from + // the key address instead of storing the absolute address of the value, the key and value + // must be stored in the same memory page. + // (8 byte key length) (key) (8 byte value length) (value) + final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; + assert(requiredSize <= PAGE_SIZE_BYTES); + size++; + bitset.set(pos); + + // If there's not enough space in the current page, allocate a new page: + if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) { + MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + } + + // Compute all of our offsets up-front: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long pageBaseOffset = currentDataPage.getBaseOffset(); + final long keySizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the key size + final long keyDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += keyLengthBytes; + final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the value size + final long valueDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += valueLengthBytes; + + // Copy the key + PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); + PlatformDependent.UNSAFE.copyMemory( + keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + // Copy the value + PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); + PlatformDependent.UNSAFE.copyMemory( + valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes); + + final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( + currentDataPage, keySizeOffsetInPage); + longArray.set(pos * 2, storedKeyAddress); + longArray.set(pos * 2 + 1, keyHashcode); + updateAddressesAndSizes(storedKeyAddress); + isDefined = true; + if (size > growthThreshold) { + growAndRehash(); + } + } + } + + /** + * Allocate new data structures for this map. When calling this outside of the constructor, + * make sure to keep references to the old data structures so that you can free them. + * + * @param capacity the new map capacity + */ + private void allocate(int capacity) { + capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); + longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2)); + bitset = new BitSet(memoryManager.allocate(capacity / 8).zero()); + + this.growthThreshold = (int) (capacity * loadFactor); + this.mask = capacity - 1; + } + + /** + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + * + * This method is idempotent. 
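+ * <p>
+ * A rough end-to-end sketch (the key/value buffers, offsets, and lengths below are placeholders;
+ * lengths must be multiples of 8 bytes):
+ * <pre>
+ *   TaskMemoryManager mm = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ *   BytesToBytesMap map = new BytesToBytesMap(mm, 64);
+ *   try {
+ *     Location loc = map.lookup(keyBase, keyOffset, 16);
+ *     if (!loc.isDefined()) {
+ *       loc.putNewKey(keyBase, keyOffset, 16, valueBase, valueOffset, 8);
+ *     }
+ *   } finally {
+ *     map.free();
+ *   }
+ * </pre>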
+ */ + public void free() { + if (longArray != null) { + memoryManager.free(longArray.memoryBlock()); + longArray = null; + } + if (bitset != null) { + memoryManager.free(bitset.memoryBlock()); + bitset = null; + } + Iterator dataPagesIterator = dataPages.iterator(); + while (dataPagesIterator.hasNext()) { + memoryManager.freePage(dataPagesIterator.next()); + dataPagesIterator.remove(); + } + assert(dataPages.isEmpty()); + } + + /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + public long getTotalMemoryConsumption() { + return ( + dataPages.size() * PAGE_SIZE_BYTES + + bitset.memoryBlock().size() + + longArray.memoryBlock().size()); + } + + /** + * Returns the total amount of time spent resizing this map (in nanoseconds). + */ + public long getTimeSpentResizingNs() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return timeSpentResizingNs; + } + + + /** + * Returns the average number of probes per key lookup. + */ + public double getAverageProbesPerLookup() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return (1.0 * numProbes) / numKeyLookups; + } + + public long getNumHashCollisions() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return numHashCollisions; + } + + /** + * Grows the size of the hash table and re-hash everything. + */ + private void growAndRehash() { + long resizeStartTime = -1; + if (enablePerfMetrics) { + resizeStartTime = System.nanoTime(); + } + // Store references to the old data structures to be used when we re-hash + final LongArray oldLongArray = longArray; + final BitSet oldBitSet = bitset; + final int oldCapacity = (int) oldBitSet.capacity(); + + // Allocate the new data structures + allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity))); + + // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) + for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { + final long keyPointer = oldLongArray.get(pos * 2); + final int hashcode = (int) oldLongArray.get(pos * 2 + 1); + int newPos = hashcode & mask; + int step = 1; + boolean keepGoing = true; + + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!bitset.isSet(newPos)) { + bitset.set(newPos); + longArray.set(newPos * 2, keyPointer); + longArray.set(newPos * 2 + 1, hashcode); + keepGoing = false; + } else { + newPos = (newPos + step) & mask; + step++; + } + } + } + + // Deallocate the old data structures. + memoryManager.free(oldLongArray.memoryBlock()); + memoryManager.free(oldBitSet.memoryBlock()); + if (enablePerfMetrics) { + timeSpentResizingNs += System.nanoTime() - resizeStartTime; + } + } + + /** Returns the next number greater or equal num that is power of 2. */ + private static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java new file mode 100644 index 0000000000000..7c321baffe82d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +/** + * Interface that defines how we can grow the size of a hash map when it is over a threshold. + */ +public interface HashMapGrowthStrategy { + + int nextCapacity(int currentCapacity); + + /** + * Double the size of the hash map every time. + */ + HashMapGrowthStrategy DOUBLING = new Doubling(); + + class Doubling implements HashMapGrowthStrategy { + @Override + public int nextCapacity(int currentCapacity) { + return currentCapacity * 2; + } + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java new file mode 100644 index 0000000000000..62c29c8cc1e4d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * Manages memory for an executor. Individual operators / tasks allocate memory through + * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. + */ +public class ExecutorMemoryManager { + + /** + * Allocator, exposed for enabling untracked allocations of temporary data structures. + */ + public final MemoryAllocator allocator; + + /** + * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. + */ + final boolean inHeap; + + /** + * Construct a new ExecutorMemoryManager. + * + * @param allocator the allocator that will be used + */ + public ExecutorMemoryManager(MemoryAllocator allocator) { + this.inHeap = allocator instanceof HeapMemoryAllocator; + this.allocator = allocator; + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). 
+ */ + MemoryBlock allocate(long size) throws OutOfMemoryError { + return allocator.allocate(size); + } + + void free(MemoryBlock memory) { + allocator.free(memory); + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java new file mode 100644 index 0000000000000..bbe83d36cf36b --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. + */ +public class HeapMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long[] array = new long[(int) (size / 8)]; + return MemoryBlock.fromLongArray(array); + } + + @Override + public void free(MemoryBlock memory) { + // Do nothing + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java new file mode 100644 index 0000000000000..5192f68c862cf --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +public interface MemoryAllocator { + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). 
+ */ + MemoryBlock allocate(long size) throws OutOfMemoryError; + + void free(MemoryBlock memory); + + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + + MemoryAllocator HEAP = new HeapMemoryAllocator(); +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java new file mode 100644 index 0000000000000..0beb743e5644e --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class MemoryBlock extends MemoryLocation { + + private final long length; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * MemoryManager. This is package-private and is modified by MemoryManager. + */ + int pageNumber = -1; + + MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); + this.length = length; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return length; + } + + /** + * Clear the contents of this memory block. Returns `this` to facilitate chaining. + */ + public MemoryBlock zero() { + PlatformDependent.UNSAFE.setMemory(obj, offset, length, (byte) 0); + return this; + } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java new file mode 100644 index 0000000000000..74ebc87dc978c --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +/** + * A memory location. Tracked either by a memory address (with off-heap allocation), + * or by an offset from a JVM object (in-heap allocation). + */ +public class MemoryLocation { + + @Nullable + Object obj; + + long offset; + + public MemoryLocation(@Nullable Object obj, long offset) { + this.obj = obj; + this.offset = offset; + } + + public MemoryLocation() { + this(null, 0); + } + + public void setObjAndOffset(Object newObj, long newOffset) { + this.obj = newObj; + this.offset = newOffset; + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java new file mode 100644 index 0000000000000..9224988e6ad69 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import java.util.*; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages the memory allocated by an individual task. + *

+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. + * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is + * addressed by the combination of a base Object reference and a 64-bit offset within that object. + * This is a problem when we want to store pointers to data structures inside of other structures, + * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits + * to address memory, we can't just store the address of the base object since it's not guaranteed + * to remain stable as the heap gets reorganized due to GC. + *

+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + *
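+ * <p>
+ * A worked example of the on-heap encoding (page number and offset are arbitrary):
+ * <pre>
+ *   // page 3, offset 256 within that page
+ *   long addr   = ((long) 3 << 51) | (256 & MASK_LONG_LOWER_51_BITS);
+ *   int  page   = (int) ((addr & MASK_LONG_UPPER_13_BITS) >>> 51);  // 3
+ *   long offset = addr & MASK_LONG_LOWER_51_BITS;                   // 256
+ * </pre>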

+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the + * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is + * approximately 35 terabytes of memory. + */ +public final class TaskMemoryManager { + + private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); + + /** + * The number of entries in the page table. + */ + private static final int PAGE_TABLE_SIZE = 1 << 13; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Similar to an operating system's page table, this array maps page numbers into base object + * pointers, allowing us to translate between the hashtable's internal 64-bit address + * representation and the baseObject+offset representation which we use to support both in- and + * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. + * When using an in-heap allocator, the entries in this map will point to pages' base objects. + * Entries are added to this map as new data pages are allocated. + */ + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + + /** + * Bitmap for tracking free pages. + */ + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + /** + * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean + * up leaked memory. + */ + private final HashSet allocatedNonPageMemory = new HashSet(); + + private final ExecutorMemoryManager executorMemoryManager; + + /** + * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods + * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, + * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. + */ + private final boolean inHeap; + + /** + * Construct a new MemoryManager. + */ + public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { + this.inHeap = executorMemoryManager.inHeap; + this.executorMemoryManager = executorMemoryManager; + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of memory that will be shared between operators. + */ + public MemoryBlock allocatePage(long size) { + if (logger.isTraceEnabled()) { + logger.trace("Allocating {} byte page", size); + } + if (size >= (1L << 51)) { + throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + final MemoryBlock page = executorMemoryManager.allocate(size); + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isDebugEnabled()) { + logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + } + return page; + } + + /** + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. 
+ */ + public void freePage(MemoryBlock page) { + if (logger.isTraceEnabled()) { + logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); + } + assert (page.pageNumber != -1) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + executorMemoryManager.free(page); + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + pageTable[page.pageNumber] = null; + if (logger.isDebugEnabled()) { + logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended + * to be used for allocating operators' internal data structures. For data pages that you want to + * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since + * that will enable intra-memory pointers (see + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's + * top-level Javadoc for more details). + */ + public MemoryBlock allocate(long size) throws OutOfMemoryError { + final MemoryBlock memory = executorMemoryManager.allocate(size); + allocatedNonPageMemory.add(memory); + return memory; + } + + /** + * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. + */ + public void free(MemoryBlock memory) { + assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; + executorMemoryManager.free(memory); + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. + * This address will remain valid as long as the corresponding page has not been freed. + */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (inHeap) { + assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } else { + return offsetInPage; + } + } + + /** + * Get the page associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (inHeap) { + final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final Object page = pageTable[pageNumber].getBaseObject(); + assert (page != null); + return page; + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + if (inHeap) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } else { + return pagePlusOffsetAddress; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. 
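+ * <p>
+ * A hedged usage sketch (the surrounding task-completion hook and {@code logger} name are
+ * hypothetical):
+ * <pre>
+ *   long leakedBytes = taskMemoryManager.cleanUpAllAllocatedMemory();
+ *   if (leakedBytes > 0) {
+ *     logger.warn("Task leaked " + leakedBytes + " bytes of managed memory");
+ *   }
+ * </pre>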
+ */ + public long cleanUpAllAllocatedMemory() { + long freedBytes = 0; + for (MemoryBlock page : pageTable) { + if (page != null) { + freedBytes += page.size(); + freePage(page); + } + } + final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException while we are iterating. + executorMemoryManager.free(memory); + iter.remove(); + } + return freedBytes; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java new file mode 100644 index 0000000000000..15898771fef25 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. + */ +public class UnsafeMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long address = PlatformDependent.UNSAFE.allocateMemory(size); + return new MemoryBlock(null, address, size); + } + + @Override + public void free(MemoryBlock memory) { + assert (memory.obj == null) : + "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + PlatformDependent.UNSAFE.freeMemory(memory.offset); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java new file mode 100644 index 0000000000000..5974cf91ff993 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.unsafe.array; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class LongArraySuite { + + @Test + public void basicTest() { + long[] bytes = new long[2]; + LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + arr.set(0, 1L); + arr.set(1, 2L); + arr.set(1, 3L); + Assert.assertEquals(2, arr.size()); + Assert.assertEquals(1L, arr.get(0)); + Assert.assertEquals(3L, arr.get(1)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java new file mode 100644 index 0000000000000..4bf132fd4053e --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import junit.framework.Assert; +import org.apache.spark.unsafe.bitset.BitSet; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class BitSetSuite { + + private static BitSet createBitSet(int capacity) { + assert capacity % 64 == 0; + return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]).zero()); + } + + @Test + public void basicOps() { + BitSet bs = createBitSet(64); + Assert.assertEquals(64, bs.capacity()); + + // Make sure the bit set starts empty. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertFalse(bs.isSet(i)); + } + + // Set every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + bs.set(i); + Assert.assertTrue(bs.isSet(i)); + } + + // Unset every bit and check it. 
+ for (int i = 0; i < bs.capacity(); i++) { + Assert.assertTrue(bs.isSet(i)); + bs.unset(i); + Assert.assertFalse(bs.isSet(i)); + } + } + + @Test + public void traversal() { + BitSet bs = createBitSet(256); + + Assert.assertEquals(-1, bs.nextSetBit(0)); + Assert.assertEquals(-1, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(64)); + + bs.set(10); + Assert.assertEquals(10, bs.nextSetBit(0)); + Assert.assertEquals(10, bs.nextSetBit(1)); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(11)); + + bs.set(11); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(11, bs.nextSetBit(11)); + + // Skip a whole word and find it + bs.set(190); + Assert.assertEquals(190, bs.nextSetBit(12)); + + Assert.assertEquals(-1, bs.nextSetBit(191)); + Assert.assertEquals(-1, bs.nextSetBit(256)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java new file mode 100644 index 0000000000000..3b9175835229c --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import junit.framework.Assert; +import org.apache.spark.unsafe.PlatformDependent; +import org.junit.Test; + +/** + * Test file based on Guava's Murmur3Hash32Test. + */ +public class Murmur3_x86_32Suite { + + private static final Murmur3_x86_32 hasher = new Murmur3_x86_32(0); + + @Test + public void testKnownIntegerInputs() { + Assert.assertEquals(593689054, hasher.hashInt(0)); + Assert.assertEquals(-189366624, hasher.hashInt(-42)); + Assert.assertEquals(-1134849565, hasher.hashInt(42)); + Assert.assertEquals(-1718298732, hasher.hashInt(Integer.MIN_VALUE)); + Assert.assertEquals(-1653689534, hasher.hashInt(Integer.MAX_VALUE)); + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(1669671676, hasher.hashLong(0L)); + Assert.assertEquals(-846261623, hasher.hashLong(-42L)); + Assert.assertEquals(1871679806, hasher.hashLong(42L)); + Assert.assertEquals(1366273829, hasher.hashLong(Long.MIN_VALUE)); + Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. 
+ Set<Integer> hashcodes = new HashSet<Integer>(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint)); + Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint)); + + hashcodes.add(hasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set<Integer> hashcodes = new HashSet<Integer>(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set<Integer> hashcodes = new HashSet<Integer>(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = ("" + i).getBytes(); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java new file mode 100644 index 0000000000000..9038cf567f1e2 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.unsafe.map; + +import java.lang.Exception; +import java.nio.ByteBuffer; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.PlatformDependent; +import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public abstract class AbstractBytesToBytesMapSuite { + + private final Random rand = new Random(42); + + private TaskMemoryManager memoryManager; + + @Before + public void setup() { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + } + + @After + public void tearDown() { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory(); + memoryManager = null; + } + } + + protected abstract MemoryAllocator getMemoryAllocator(); + + private static byte[] getByteArray(MemoryLocation loc, int size) { + final byte[] arr = new byte[size]; + PlatformDependent.UNSAFE.copyMemory( + loc.getBaseObject(), + loc.getBaseOffset(), + arr, + BYTE_ARRAY_OFFSET, + size + ); + return arr; + } + + private byte[] getRandomByteArray(int numWords) { + Assert.assertTrue(numWords > 0); + final int lengthInBytes = numWords * 8; + final byte[] bytes = new byte[lengthInBytes]; + rand.nextBytes(bytes); + return bytes; + } + + /** + * Fast equality checking for byte arrays, since these comparisons are a bottleneck + * in our stress tests. + */ + private static boolean arrayEquals( + byte[] expected, + MemoryLocation actualAddr, + long actualLengthBytes) { + return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals( + expected, + BYTE_ARRAY_OFFSET, + actualAddr.getBaseObject(), + actualAddr.getBaseOffset(), + expected.length + ); + } + + @Test + public void emptyMap() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + try { + Assert.assertEquals(0, map.size()); + final int keyLengthInWords = 10; + final int keyLengthInBytes = keyLengthInWords * 8; + final byte[] key = getRandomByteArray(keyLengthInWords); + Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + } finally { + map.free(); + } + } + + @Test + public void setAndRetrieveAKey() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + final int recordLengthWords = 10; + final int recordLengthBytes = recordLengthWords * 8; + final byte[] keyData = getRandomByteArray(recordLengthWords); + final byte[] valueData = getRandomByteArray(recordLengthWords); + try { + final BytesToBytesMap.Location loc = + map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + // After storing the key and value, the other location methods should return results that + // reflect the result of this store without us having to call lookup() again on the same key. 
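+ // (getByteArray() copies the data at those addresses back into byte[]s so we can compare them.)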
+ Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + // After calling lookup() the location should still point to the correct data. + Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + try { + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + Assert.fail("Should not be able to set a new value for a key"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + } finally { + map.free(); + } + } + + @Test + public void iteratorTest() throws Exception { + final int size = 128; + BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); + try { + for (long i = 0; i < size; i++) { + final long[] value = new long[] { i }; + final BytesToBytesMap.Location loc = + map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } + final java.util.BitSet valuesSeen = new java.util.BitSet(size); + final Iterator<BytesToBytesMap.Location> iter = map.iterator(); + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long value = PlatformDependent.UNSAFE.getLong( + valueAddress.getBaseObject(), valueAddress.getBaseOffset()); + Assert.assertEquals(key, value); + valuesSeen.set((int) value); + } + Assert.assertEquals(size, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + + @Test + public void randomizedStressTest() { + final int size = 65536; + // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays + // into ByteBuffers in order to use them as keys here.
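+ // ByteBuffer.wrap() provides content-based equals() and hashCode() without copying the array.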
+ final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>(); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); + + try { + // Fill the map to 90% full so that we can trigger probing + for (int i = 0; i < size * 0.9; i++) { + final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); + final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); + if (!expected.containsKey(ByteBuffer.wrap(key))) { + expected.put(ByteBuffer.wrap(key), value); + final BytesToBytesMap.Location loc = map.lookup( + key, + BYTE_ARRAY_OFFSET, + key.length + ); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + key, + BYTE_ARRAY_OFFSET, + key.length, + value, + BYTE_ARRAY_OFFSET, + value.length + ); + // After calling putNewKey, the following should be true, even before calling + // lookup(): + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(key.length, loc.getKeyLength()); + Assert.assertEquals(value.length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + } + } + + for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) { + final byte[] key = entry.getKey().array(); + final byte[] value = entry.getValue(); + final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + Assert.assertTrue(loc.isDefined()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + } + } finally { + map.free(); + } + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java new file mode 100644 index 0000000000000..5a10de49f54fe --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.UNSAFE; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java new file mode 100644 index 0000000000000..12cc9b25d93b3 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.HEAP; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java new file mode 100644 index 0000000000000..932882f1ca248 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.junit.Assert; +import org.junit.Test; + +public class TaskMemoryManagerSuite { + + @Test + public void leakedNonPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocate(1024); // leak memory + Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void leakedPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocatePage(4096); // leak memory + Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); + } + +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 93ae45133ce24..70cb57ffd8c69 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -95,14 +95,8 @@ private[spark] class ApplicationMaster( val fs = FileSystem.get(yarnConf) - Utils.addShutdownHook { () => - // If the SparkContext is still registered, shut it down as a best case effort in case - // users do not call sc.stop or do System.exit(). 
- val sc = sparkContextRef.get() - if (sc != null) { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() - } + // This shutdown hook should run *after* the SparkContext is shut down. + Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY - 1) { () => val maxAppAttempts = client.getMaxRegAttempts(sparkConf, yarnConf) val isLastAttempt = client.getAttemptId().getAttemptId() >= maxAppAttempts diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 019afbd1a1743..4abcf7307a388 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -39,7 +39,7 @@ import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.Master import org.apache.hadoop.mapreduce.MRJobConfig import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.hadoop.security.token.Token +import org.apache.hadoop.security.token.{TokenIdentifier, Token} import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -226,6 +226,7 @@ private[spark] class Client( val distributedUris = new HashSet[String] obtainTokensForNamenodes(nns, hadoopConf, credentials) obtainTokenForHiveMetastore(hadoopConf, credentials) + obtainTokenForHBase(hadoopConf, credentials) val replication = sparkConf.getInt("spark.yarn.submit.file.replication", fs.getDefaultReplication(dst)).toShort @@ -354,7 +355,7 @@ private[spark] class Client( val dir = new File(path) if (dir.isDirectory()) { dir.listFiles().foreach { file => - if (!hadoopConfFiles.contains(file.getName())) { + if (file.isFile && !hadoopConfFiles.contains(file.getName())) { hadoopConfFiles(file.getName()) = file } } @@ -1084,6 +1085,41 @@ object Client extends Logging { } } + /** + * Obtain security token for HBase. + */ + def obtainTokenForHBase(conf: Configuration, credentials: Credentials): Unit = { + if (UserGroupInformation.isSecurityEnabled) { + val mirror = universe.runtimeMirror(getClass.getClassLoader) + + try { + val confCreate = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.HBaseConfiguration"). + getMethod("create", classOf[Configuration]) + val obtainToken = mirror.classLoader. + loadClass("org.apache.hadoop.hbase.security.token.TokenUtil"). + getMethod("obtainToken", classOf[Configuration]) + + logDebug("Attempting to fetch HBase security token.") + + val hbaseConf = confCreate.invoke(null, conf) + val token = obtainToken.invoke(null, hbaseConf).asInstanceOf[Token[TokenIdentifier]] + credentials.addToken(token.getService, token) + + logInfo("Added HBase security token to credentials.") + } catch { + case e:java.lang.NoSuchMethodException => + logInfo("HBase Method not found: " + e) + case e:java.lang.ClassNotFoundException => + logDebug("HBase Class not found: " + e) + case e:java.lang.NoClassDefFoundError => + logDebug("HBase Class not found: " + e) + case e:Exception => + logError("Exception when obtaining HBase security token: " + e) + } + } + } + /** * Return whether the two file systems are the same. */