diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dfa1d399a1ab..3099de56fcd9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -24,6 +24,7 @@ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.CountDownLatch import java.util.UUID.randomUUID import scala.collection.{Map, Set} @@ -196,6 +197,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] val conf = config.clone() conf.validateSettings() + /** Has the server been marked for start. */ + val startLatch = new CountDownLatch(1) + /** * Return a copy of this SparkContext's configuration. The configuration ''cannot'' be * changed at runtime. @@ -1441,6 +1445,36 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli }.getOrElse(Utils.getCallSite()) } + /** + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. The allowLocal + * flag specifies whether the scheduler can run the computation on the driver rather than + * shipping it out to the cluster, for short actions like first(). + */ + def runJobWithPS[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U): Array[U] = { + if (stopped) { + throw new IllegalStateException("SparkContext has been shutdown") + } + val callSite = getCallSite + val cleanedFunc = clean(func) + logInfo("Starting run parameter server job: " + callSite.shortForm) + if (conf.getBoolean("spark.logLineage", false)) { + logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) + } + val results = new Array[U](rdd.partitions.size) + val resultHandler: (Int, U) => Unit = (pid, value) => { + logInfo(s"partition number $pid, value: $value") + results(pid) = value + } + dagScheduler.runJob(rdd, cleanedFunc, 0 until rdd.partitions.size, + callSite, false, resultHandler, localProperties.get) + progressBar.foreach(_.finishAll()) + rdd.doCheckpoint() + results + } + /** * Run a function on a given set of partitions in an RDD and pass the results to the given * handler function. This is the main entry point for all actions in Spark. 
The allowLocal diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 7d7fe1a44631..d7077130cf1f 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.ps.PSClient import org.apache.spark.util.TaskCompletionListener @@ -133,4 +134,6 @@ abstract class TaskContext extends Serializable { /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics + + def getPSClient: PSClient } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 337c8e4ebebc..bdcecb428dc9 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics +import org.apache.spark.ps.PSClient import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import scala.collection.mutable.ArrayBuffer @@ -27,6 +28,7 @@ private[spark] class TaskContextImpl( val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, + val psClient: Option[PSClient] = None, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -92,5 +94,7 @@ private[spark] class TaskContextImpl( override def isRunningLocally(): Boolean = runningLocally override def isInterrupted(): Boolean = interrupted + + override def getPSClient: PSClient = psClient.get } diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 415bd5059169..e80dd5ed2746 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -41,6 +41,10 @@ private[spark] class ClientArguments(args: Array[String]) { var supervise: Boolean = DEFAULT_SUPERVISE var memory: Int = DEFAULT_MEMORY var cores: Int = DEFAULT_CORES + var psServerMemory = DEFAULT_MEMORY + var psServerCores = DEFAULT_CORES + var numPSservers = DEFAULT_NUMBER_PS_SERVERS + var enablePS: Boolean = false private var _driverOptions = ListBuffer[String]() def driverOptions = _driverOptions.toSeq @@ -58,6 +62,22 @@ private[spark] class ClientArguments(args: Array[String]) { memory = value parse(tail) + case ("--num-servers") :: IntParam(value) :: tail => + numPSservers = value + parse(tail) + + case ("--server-memory") :: MemoryParam(value) :: tail => + psServerMemory = value + parse(tail) + + case ("--server-cores") :: IntParam(value) :: tail => + psServerCores = value + parse(tail) + + case ("--enablePS") :: value :: tail => + enablePS = value.toBoolean + parse(tail) + case ("--supervise" | "-s") :: tail => supervise = true parse(tail) @@ -119,6 +139,7 @@ private[spark] class ClientArguments(args: Array[String]) { object ClientArguments { private[spark] val DEFAULT_CORES = 1 private[spark] val DEFAULT_MEMORY = 512 // MB + private[spark] val DEFAULT_NUMBER_PS_SERVERS = 1 private[spark] val DEFAULT_SUPERVISE = false def isValidJarUrl(s: String): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala 
index 4a74641f4e1f..d75eedb96f20 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -367,8 +367,12 @@ object SparkSubmit { OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"), + OptionAssigner(args.numPSServers, YARN, CLUSTER, clOption = "--num-servers"), OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), + OptionAssigner(args.psServerMemory, YARN, CLUSTER, clOption = "--server-memory"), + OptionAssigner(args.psServerCores, YARN, CLUSTER, clOption = "--server-cores"), + OptionAssigner(args.enablePS.toString, YARN, CLUSTER, clOption = "--enablePS"), OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 82e66a374249..04221cf1463b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -34,6 +34,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var deployMode: String = null var executorMemory: String = null var executorCores: String = null + var psServerMemory: String = null + var psServerCores: String = null var totalExecutorCores: String = null var propertiesFile: String = null var driverMemory: String = null @@ -42,6 +44,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverExtraJavaOptions: String = null var queue: String = null var numExecutors: String = null + var numPSServers: String = null var files: String = null var archives: String = null var mainClass: String = null @@ -52,6 +55,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var packages: String = null var repositories: String = null var ivyRepoPath: String = null + var enablePS: Boolean = false var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null @@ -315,6 +319,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St numExecutors = value parse(tail) + case ("--num-servers") :: value :: tail => + numPSServers = value + parse(tail) + case ("--total-executor-cores") :: value :: tail => totalExecutorCores = value parse(tail) @@ -327,6 +335,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St executorMemory = value parse(tail) + case ("--server-cores") :: value :: tail => + psServerCores = value + parse(tail) + + case ("--server-memory") :: value :: tail => + psServerMemory = value + parse(tail) + case ("--driver-memory") :: value :: tail => driverMemory = value parse(tail) @@ -417,6 +433,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St verbose = true parse(tail) + case ("--enablePS") :: tail => + enablePS = true + parse(tail) + case ("--version") :: tail => SparkSubmit.printVersionAndExit() diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala 
b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index dd19e4947db1..70ef02ca30fc 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -17,8 +17,10 @@ package org.apache.spark.executor +import java.io._ import java.net.URL import java.nio.ByteBuffer +import java.lang.management.ManagementFactory import scala.collection.mutable import scala.concurrent.Await @@ -31,9 +33,15 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher +import org.apache.spark.ps.CoarseGrainedParameterServerMessage.NotifyClient +import org.apache.spark.ps.{PSClient, ServerInfo} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.KillTask +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutorFailed +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask +import org.apache.spark.util._ private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -48,6 +56,8 @@ private[spark] class CoarseGrainedExecutorBackend( var executor: Executor = null var driver: ActorSelection = null + val psServers = new mutable.HashMap[Long, ServerInfo]() + var psClient: Option[PSClient] = None override def preStart() { logInfo("Connecting to driver: " + driverUrl) @@ -72,6 +82,22 @@ private[spark] class CoarseGrainedExecutorBackend( logError("Slave registration failed: " + message) System.exit(1) + case AddNewPSServer(serverInfo) => + println("Adding new ps server") + psClient.synchronized { + if (!psClient.isDefined) { + val clientId: String = executorId + psClient = Some(new PSClient(clientId, context, env.conf)) + } + psClient.get.addServer(serverInfo) + val serverId = serverInfo.serverId + psServers(serverId) = serverInfo + } + + case NotifyClient(message) => + println(s"message:$message") + psClient.get.notifyTasks() + case LaunchTask(data) => if (executor == null) { logError("Received LaunchTask command but executor was null") @@ -110,6 +136,8 @@ private[spark] class CoarseGrainedExecutorBackend( override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { driver ! StatusUpdate(executorId, taskId, state, data) } + + override def getPSClient: Option[PSClient] = psClient } private[spark] object CoarseGrainedExecutorBackend extends Logging { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c6ff38d527d8..54131afd0f70 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -200,7 +200,8 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
taskStart = System.currentTimeMillis() - val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + val psClient = execBackend.getPSClient + val value = task.run(psClient, taskAttemptId = taskId, attemptNumber = attemptNumber) val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala index e07cb31cbe4b..52ceef62a680 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -20,11 +20,14 @@ package org.apache.spark.executor import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.ps.PSClient /** * A pluggable interface used by the Executor to send updates to the cluster scheduler. */ private[spark] trait ExecutorBackend { def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) + + def getPSClient: Option[PSClient] } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index cfd672e1d8a9..723917fa0ba3 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -28,6 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.ps.PSClient import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData} import org.apache.spark.util.{SignalLogger, Utils} @@ -48,6 +49,8 @@ private[spark] class MesosExecutorBackend .build()) } + override def getPSClient: Option[PSClient] = None + override def registered( driver: ExecutorDriver, executorInfo: ExecutorInfo, diff --git a/core/src/main/scala/org/apache/spark/ps/CoarseGrainedParameterServerBackend.scala b/core/src/main/scala/org/apache/spark/ps/CoarseGrainedParameterServerBackend.scala new file mode 100644 index 000000000000..9fe264038e65 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/CoarseGrainedParameterServerBackend.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ps + +import java.net.URL +import java.nio.ByteBuffer + +import scala.concurrent.Await +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import akka.pattern.Patterns +import akka.remote.RemotingLifecycleEvent +import akka.actor.{Props, ActorSelection, Actor} + +import org.apache.spark.executor.ExecutorBackend +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.ps.CoarseGrainedParameterServerMessage._ +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, Logging} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.TaskState._ +import org.apache.spark.util._ + +/** + * `parameter server` ServerBackend + */ +private[spark] class CoarseGrainedParameterServerBackend( + driverUrl: String, + executorId: String, + hostPort: String, + executorVCores: Int, + userClassPath: Seq[URL], + env: SparkEnv) + extends Actor with ActorLogReceive with ExecutorBackend with Logging { + + Utils.checkHostPort(hostPort, "Expected hostport") + + var driver: ActorSelection = null + + var psServer: PSServer = null + + override def preStart() { + logInfo("Connecting to driver: " + driverUrl) + driver = context.actorSelection(driverUrl) + driver ! RegisterServer(executorId, hostPort, executorVCores, + System.getProperty("spark.yarn.container.id", "")) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receiveWithLogging = { + case SetParameter(key: String, value: Array[Double], clock: Int) => + logInfo(s"set $key to $value") + val success: Boolean = psServer.setParameter[String, Array[Double]](key, value, clock) + sender ! success + + case GetParameter(key: String, clock: Int) => + logInfo(s"request $key") + val (success, value) = psServer.getParameter[String, Array[Double]](key, clock) + logInfo(s"get parameter: $value") + sender ! Parameter(success, value) + + case UpdateParameter(key: String, value: Array[Double], clock: Int) => + logInfo(s"update $key") + val success: Boolean = psServer.updateParameter[String, Array[Double]](key, value, clock) + sender ! success + + case RegisteredServer(serverId: Long) => + logInfo(s"registered server with serverId: $serverId") + val sparkConf = env.conf + val agg = (deltaKVs: ArrayBuffer[Array[Double]]) + => { + val size = deltaKVs.size + deltaKVs.reduce((a, b) => a.zip(b).map(e => { + e._1 + e._2 + })).map(e => e / size) + } + val func = (arr1: Array[Double], arr2: Array[Double]) => arr1.zip(arr2).map(e => e._1 + e._2) + psServer = new PSServer(context, sparkConf, serverId, agg, func) + + case UpdateClock(clientId: String, clock: Int) => + logInfo(s"update clock $clock from client $clientId") + val pause: Boolean = psServer.updateClock(clientId, clock) + sender ! 
pause + + case InitPSClient(clientId: String) => + logInfo(s"client $clientId is coming.") + psServer.initPSClient(clientId) + + case NotifyServer(executorUrl: String) => + println("notify server: " + executorUrl) + psServer.addPSClient(executorUrl) + + } + + override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { + + } + + override def getPSClient: Option[PSClient] = None + +} + +object CoarseGrainedParameterServerBackend extends Logging { + + private def run( + driverUrl: String, + executorId: String, + hostname: String, + cores: Int, + appId: String, + workerUrl: Option[String], + userClassPath: Seq[URL]) { + + SignalLogger.register(log) + + SparkHadoopUtil.get.runAsSparkUser { () => + // Debug code + Utils.checkHost(hostname) + + // Bootstrap to fetch the driver's Spark properties. + val executorConf = new SparkConf + val port = executorConf.getInt("spark.executor.port", 0) + val (fetcher, _) = AkkaUtils.createActorSystem( + "driverPropsFetcher", + hostname, + port, + executorConf, + new SecurityManager(executorConf)) + val driver = fetcher.actorSelection(driverUrl) + val timeout = AkkaUtils.askTimeout(executorConf) + val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) + val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + Seq[(String, String)](("spark.app.id", appId)) + fetcher.shutdown() + + // Create SparkEnv using properties we fetched from the driver. + val driverConf = new SparkConf() + val executorVCores = cores * driverConf.get("spark.cores.ratio", "1").toInt + for ((key, value) <- props) { + // this is required for SSL in standalone mode + if (SparkConf.isExecutorStartupConf(key)) { + driverConf.setIfMissing(key, value) + } else { + driverConf.set(key, value) + } + } + val env = SparkEnv.createExecutorEnv( + driverConf, executorId, hostname, port, executorVCores, isLocal = false) + + // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore. + val boundPort = env.conf.getInt("spark.executor.port", 0) + assert(boundPort != 0) + + // Start the CoarseGrainedExecutorBackend actor. 
+ val sparkHostPort = hostname + ":" + boundPort + env.actorSystem.actorOf( + Props(classOf[CoarseGrainedParameterServerBackend], + driverUrl, executorId, sparkHostPort, executorVCores, userClassPath, env), + name = "Server") + + env.actorSystem.awaitTermination() + } + } + + def main(args: Array[String]) { + var driverUrl: String = null + var executorId: String = null + var hostname: String = null + var cores: Int = 0 + var appId: String = null + var workerUrl: Option[String] = None + val userClassPath = new mutable.ListBuffer[URL]() + + var argv = args.toList + while (!argv.isEmpty) { + argv match { + case ("--driver-url") :: value :: tail => + driverUrl = value + argv = tail + case ("--executor-id") :: value :: tail => + executorId = value + argv = tail + case ("--hostname") :: value :: tail => + hostname = value + argv = tail + case ("--cores") :: value :: tail => + cores = value.toInt + argv = tail + case ("--app-id") :: value :: tail => + appId = value + argv = tail + case ("--worker-url") :: value :: tail => + // Worker url is used in spark standalone mode to enforce fate-sharing with worker + workerUrl = Some(value) + argv = tail + case ("--user-class-path") :: value :: tail => + userClassPath += new URL(value) + argv = tail + case Nil => + case tail => + System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") + printUsageAndExit() + } + } + + if (driverUrl == null || executorId == null || hostname == null || cores <= 0 || + appId == null) { + printUsageAndExit() + } + + run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) + } + + private def printUsageAndExit() = { + System.err.println( + """ + |"Usage: CoarseGrainedParameterServerBackend [options] + | + | Options are: + | --driver-url + | --executor-id + | --hostname + | --cores + | --app-id + | --worker-url + | --user-class-path + |""".stripMargin) + System.exit(1) + } +} diff --git a/core/src/main/scala/org/apache/spark/ps/CoarseGrainedParameterServerMessage.scala b/core/src/main/scala/org/apache/spark/ps/CoarseGrainedParameterServerMessage.scala new file mode 100644 index 000000000000..b5b2f27faa92 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/CoarseGrainedParameterServerMessage.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ps + +/** + * CoarseGrainedParameterServerMessage + */ +private[spark] sealed trait CoarseGrainedParameterServerMessage extends Serializable + +private[spark] object CoarseGrainedParameterServerMessage { + + case class SetParameter(key: String, value: Array[Double], clock: Int) + extends CoarseGrainedParameterServerMessage + + case class GetParameter(key: String, clock: Int) + extends CoarseGrainedParameterServerMessage + + case class UpdateParameter(key: String, value: Array[Double], clock: Int) + extends CoarseGrainedParameterServerMessage + + case class Parameter(success: Boolean, value: Array[Double]) + extends CoarseGrainedParameterServerMessage + + case class UpdateClock(clientId: String, clock: Int) + extends CoarseGrainedParameterServerMessage + + case class NotifyServer(executorUrl: String) + extends CoarseGrainedParameterServerMessage + + case class InitPSClient(clientId: String) + extends CoarseGrainedParameterServerMessage + + case class NotifyClient(message: String) + extends CoarseGrainedParameterServerMessage + +} diff --git a/core/src/main/scala/org/apache/spark/ps/PSClient.scala b/core/src/main/scala/org/apache/spark/ps/PSClient.scala new file mode 100644 index 000000000000..cb8bab2e8d1b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/PSClient.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps + +import scala.collection.mutable._ +import scala.concurrent.Await + +import akka.actor.{ActorContext, ActorSelection} +import akka.pattern.ask + +import org.apache.spark.ps.CoarseGrainedParameterServerMessage._ +import org.apache.spark.util.AkkaUtils +import org.apache.spark.{SparkConf, SparkEnv} + +/** + * Client Role in `Parameter Server` + */ +class PSClient(clientId: String, context: ActorContext, conf: SparkConf) { + val interval = conf.getInt("spark.ps.server.heartbeatInterval", 1000) + val timeout = AkkaUtils.lookupTimeout(conf) + val retryAttempts = AkkaUtils.numRetries(conf) + val retryIntervalMs = AkkaUtils.retryWaitMs(conf) + val serverId2ActorRef = new HashMap[Long, ActorSelection]() + var currentClock: ThreadLocal[Int] = new ThreadLocal[Int] { + override def initialValue(): Int = 0 + } + val waiting: String = "waiting" + + def addServer(serverInfo: ServerInfo) { + val serverRef = context.actorSelection(serverInfo.serverUrl) + serverId2ActorRef(serverInfo.serverId) = serverRef + serverRef ! 
InitPSClient(clientId) + } + + def get(key: String): Array[Double] = { + val message = GetParameter(key, currentClock.get()) + val serverRef = serverId2ActorRef.head._2 + val future = serverRef.ask(message)(timeout) + val response = Await.result(future, timeout) + response.asInstanceOf[Parameter].value + } + + def set(key: String, value: Array[Double]): Unit = { + val serverRef = serverId2ActorRef.head._2 + serverRef ! SetParameter(key, value, currentClock.get()) + } + + def update(key: String, value: Array[Double]): Unit = { + val serverRef = serverId2ActorRef.head._2 + serverRef ! UpdateParameter(key, value, currentClock.get()) + } + + def serverActorRef(url: String) = { + SparkEnv.get.actorSystem.actorSelection(url) + } + + def clock(): Unit = { + val cc = currentClock.get() + val serverRef = serverId2ActorRef.head._2 + val message = UpdateClock(clientId, cc) + val future = serverRef.ask(message)(timeout) + val pause = Await.result(future, timeout).asInstanceOf[Boolean] + if (pause) { + waiting.synchronized { + waiting.wait() + } + } + currentClock.set(cc + 1) + } + + def initClock(clock: Int): Unit = { + currentClock.set(clock) + } + + def notifyTasks(): Unit = { + waiting.synchronized { + waiting.notifyAll() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ps/PSContext.scala b/core/src/main/scala/org/apache/spark/ps/PSContext.scala new file mode 100644 index 000000000000..5b93a1c863be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/PSContext.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ps + +import scala.collection.mutable + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.rdd.RDD + +/** + * Parameter Server Context + */ +class PSContext(sc : SparkContext) extends Logging { + + initialize() + + private def initialize() { + sc.startLatch.await() + } + + def registerPSTable[T](model: RDD[T]): Int = { + -1 + } + + def loadPSModel(model: RDD[(String, Array[Double])]): Unit = { + model.runWithPS[Unit](1, (array, psClient) => { + array.foreach(e => { + psClient.set(e._1, e._2) + }) + }) + } + + def downloadPSModel(keys: Array[String], numPartition: Int): Array[Array[Double]] = { + sc.parallelize(keys, numPartition) + .runWithPS[Array[Array[Double]]](1, (array, psClient) => { + val res = new mutable.ArrayBuffer[Array[Double]] + array.foreach(e => { + val value = psClient.get(e) + res.+=(value) + }) + res.toArray + }).flatMap(arrs => arrs.toIterator) + } +} diff --git a/core/src/main/scala/org/apache/spark/ps/PSServer.scala b/core/src/main/scala/org/apache/spark/ps/PSServer.scala new file mode 100644 index 000000000000..a4000db9734d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/PSServer.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ps + +import scala.reflect.ClassTag +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import akka.actor.ActorContext + +import org.apache.spark.SparkConf +import org.apache.spark.ps.CoarseGrainedParameterServerMessage.NotifyClient +import org.apache.spark.ps.storage.PSStorage + +/** + * Server Role in `Parameter Server` + */ +private[spark] class PSServer( + context: ActorContext, + sparkConf: SparkConf, + serverId: Long, + agg: ArrayBuffer[Array[Double]] => Array[Double], + func: (Array[Double], Array[Double]) => Array[Double]) { + + + private val psKVStorage: PSStorage = PSStorage.getKVStorage(sparkConf) + + private val inValidV = new Array[Double](0) + private var globalClock: Int = 0 + private val staleClock: Int = sparkConf.get("spark.ps.stale.clock", "0").toInt + private val allTaskClientClock = new mutable.HashMap[String, Int]() + private val allTaskClients = new mutable.HashSet[String]() + + def getParameter[K: ClassTag, V: ClassTag](key: K, clock: Int): (Boolean, V) = { + if (checkValidity(clock)) { + (true, psKVStorage.get[K, V](key).get) + } else { + (false, inValidV.asInstanceOf[V]) + } + } + + def setParameter[K: ClassTag, V: ClassTag](key: K, value: V, clock: Int): Boolean = { + if (checkValidity(clock)) { + psKVStorage.put[K, V](key, value) + true + } else { + false + } + } + + def updateParameter[K: ClassTag, V: ClassTag](key: K, value: V, clock: Int): Boolean = { + if (checkValidity(clock)) { + psKVStorage.update[K, V](key, value) + true + } else { + false + } + } + + /** + * batch update parameter from PSTask + * every k-v pair has same array index + * @param keys ordered keys + * @param values ordered values + */ + def batchUpdateParameter[K: ClassTag, V: ClassTag](keys: Array[K], values: Array[V]): Unit = { + + } + + def updateClock(clientId: String, clock: Int): Boolean = { + println(s"global clock: $globalClock, clock: $clock") + allTaskClientClock(clientId) = clock + allTaskClientClock.synchronized { + val slower = allTaskClientClock.filter(_._2 < globalClock) + if (slower.size == 0) { + update() + globalClock += 1 + notifyAllClients() + } + } + + if ((clock + 1) > globalClock + staleClock) { + true + } else { + false + } + } + + def initPSClient(clientId: String): Unit = { + allTaskClientClock(clientId) = -1 + } + + def addPSClient(executorUrl: String): Unit = { + allTaskClients.add(executorUrl) + } + + def checkValidity(clock: Int): Boolean = { + if (clock <= globalClock + staleClock) { + true + } else { + false + } + } + + def notifyAllClients(): Unit = { + println("notify all ps clients") + val message = "notify all ps clients" + allTaskClients.foreach(e => { + val executorRef = context.actorSelection(e) + executorRef ! NotifyClient(message) + }) + } + + private def update(): Unit = { + psKVStorage.applyDelta[String, Array[Double]](agg, func) + } +} diff --git a/core/src/main/scala/org/apache/spark/ps/PSServerManager.scala b/core/src/main/scala/org/apache/spark/ps/PSServerManager.scala new file mode 100644 index 000000000000..929f7f2e9bbf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/PSServerManager.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable.HashMap + +/** + * PSServerManager + */ +private[spark] class PSServerManager { + + val nextServerId = new AtomicLong(0) + + val containerId2Server = new HashMap[String, ServerInfo]() + val executorId2Server = new HashMap[String, ServerInfo]() + val executorId2ServerId = new HashMap[String, Long]() + + def addPSServer( + executorId: String, + hostPort: String, + containerId: String, + serverInfo: ServerInfo) { + containerId2Server(containerId) = serverInfo + executorId2Server(executorId) = serverInfo + executorId2ServerId(executorId) = serverInfo.serverId + } + + def newServerId(): Long = nextServerId.getAndIncrement + + def getAllServers: Iterator[(String, ServerInfo)] = executorId2Server.toArray.toIterator + +} diff --git a/core/src/main/scala/org/apache/spark/ps/ServerInfo.scala b/core/src/main/scala/org/apache/spark/ps/ServerInfo.scala new file mode 100644 index 000000000000..72de652bdf0b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/ServerInfo.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps + +import akka.actor.Address + +/** + * ServerInfo + */ +private[spark] class ServerInfo( + val serverId: Long, + val serverUrl: String, + val serverAddress: Address, + val serverHost: String, + val totalCores: Int) extends Serializable diff --git a/core/src/main/scala/org/apache/spark/ps/storage/MemoryStorage.scala b/core/src/main/scala/org/apache/spark/ps/storage/MemoryStorage.scala new file mode 100644 index 000000000000..d74d55eae04e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/storage/MemoryStorage.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps.storage + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.SparkConf + +/** + * memory kv storage + */ +class MemoryStorage(conf: SparkConf) extends PSStorage { + + private val keyValues = new mutable.HashMap[String, Array[Double]]() + private val updatedKeyValues = new mutable.HashMap[String, ArrayBuffer[Array[Double]]]() + + private val inValidV = new Array[Double](0) + + /** + * fetch value from parameter server storage. + * @param k: key + * @tparam K: type of Key + * @tparam V: type of Value + * @return value + */ + override def get[K: ClassTag, V: ClassTag](k: K): Option[V] = { + Some(keyValues.getOrElse(k.asInstanceOf[String], inValidV)).asInstanceOf[Option[V]] + } + + /** + * fetch multi values from parameter server storage. + * @param ks: multi keys + * @tparam K: type of Key + * @tparam V: type of Value + * @return multi values + */ + override def multiGet[K: ClassTag, V: ClassTag](ks: Array[K]): Array[Option[V]] = { + Array(None) + } + + /** + * put value into parameter server storage with specific key + * @param k: key + * @param v: value + * @tparam K: type of Key + * @tparam V: type of Value + * @return + */ + override def put[K: ClassTag, V: ClassTag](k: K, v: V): Boolean = { + keyValues.synchronized { + keyValues(k.asInstanceOf[String]) = v.asInstanceOf[Array[Double]] + } + + true + } + + /** + * put values into parameter server storage with specific key + * @param ks: keys + * @param vs: values + * @tparam K: type of Key + * @tparam V: type of Value + * @return + */ + override def multiPut[K: ClassTag, V: ClassTag](ks: Array[K], vs: Array[V]): Boolean = false + + override def update[K: ClassTag, V: ClassTag](k: K, v: V): Boolean = { + updatedKeyValues.synchronized { + if (updatedKeyValues.contains(k.asInstanceOf[String])) { + updatedKeyValues(k.asInstanceOf[String]) += v.asInstanceOf[Array[Double]] + } else { + val updatedValues = new ArrayBuffer[Array[Double]] + updatedValues += v.asInstanceOf[Array[Double]] + updatedKeyValues(k.asInstanceOf[String]) = updatedValues + } + } + + true + } + + /** + * whether caching data in memory + * @param bool flag + */ + override def setCacheInMemory(bool: Boolean = true): Unit = ??? + + /** + * clear data of specific table from parameter server storage. + * @param tbId: table ID + * @return + */ + override def clear(tbId: Long): Boolean = false + + /** + * clear current node(server or task) all data in memery and local disk + * @return + */ + override def clearAll(): Boolean = ??? 
+ + /** + * check whether the parameter server storage contains specific key/value + * @param k: key + * @tparam K: type of key + * @return + */ + override def exists[K: ClassTag](k: K): Boolean = keyValues.contains(k.asInstanceOf[String]) + + /** + * iterator all data + * @tparam K: type of Key + * @tparam V: type of Value + * @return + */ + override def toIterator[K: ClassTag, V: ClassTag](): Iterator[(K, V)] = + keyValues.toIterator.asInstanceOf[Iterator[(K, V)]] + + /** + * apply `agg` and `func` to deltas and original values + * @param agg: functions applied to deltas + * @param func: functions allied to original values with specific deltas + * @tparam K: type of Key + * @tparam V: type of Value + */ + override def applyDelta[K: ClassTag, V: ClassTag](agg: ArrayBuffer[V] => V, func: (V, V) => V) + : Unit = { + updatedKeyValues.foreach( e => { + if (keyValues.contains(e._1)) { + val deltaV = agg(e._2.asInstanceOf[ArrayBuffer[V]]) + println(s"update agg delta ${deltaV.asInstanceOf[Array[Double]].mkString(", ")}") + keyValues(e._1) = func(keyValues(e._1).asInstanceOf[V], deltaV.asInstanceOf[V]). + asInstanceOf[Array[Double]] + println(s"set ${e._1} to ${keyValues(e._1).mkString(", ")}") + e._2.clear() + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/ps/storage/PSStorage.scala b/core/src/main/scala/org/apache/spark/ps/storage/PSStorage.scala new file mode 100644 index 000000000000..441b55a6013e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/storage/PSStorage.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps.storage + +import java.nio.charset.Charset + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, Logging} + +/** + * local storage interface for ``org.apache.spark.ps.PSServer`` + * and ``PSTask`` + * fetch data from server, cache in local + */ +private[ps] trait PSStorage extends Logging { + + lazy val charset = Charset.forName("UTF-8") + + /** + * fetch value from parameter server storage. + * @param k: key + * @tparam K: type of Key + * @tparam V: type of Value + * @return value + */ + def get[K: ClassTag, V: ClassTag](k: K): Option[V] + + /** + * fetch multi values from parameter server storage. 
+ * @param ks: multi keys + * @tparam K: type of Key + * @tparam V: type of Value + * @return multi values + */ + def multiGet[K: ClassTag, V: ClassTag](ks: Array[K]): Array[Option[V]] + + /** + * put value into parameter server storage with specific key + * @param k: key + * @param v: value + * @tparam K: type of Key + * @tparam V: type of Value + * @return + */ + def put[K: ClassTag, V: ClassTag](k: K, v: V): Boolean + + /** + * put values into parameter server storage with specific key + * @param ks: keys + * @param vs: values + * @tparam K: type of Key + * @tparam V: type of Value + * @return + */ + def multiPut[K: ClassTag, V: ClassTag](ks: Array[K], vs: Array[V]): Boolean + + /** + * update value into parameter server storage with specific key + * @param k: key + * @param v: value + * @tparam K: type of Key + * @tparam V: type of Value + * @return + */ + def update[K: ClassTag, V: ClassTag](k: K, v: V): Boolean + + /** + * whether caching data in memory + * @param bool flag + */ + def setCacheInMemory(bool: Boolean = true): Unit + + /** + * clear data of specific table from parameter server storage. + * @param tbId: table ID + * @return + */ + def clear(tbId: Long): Boolean + + /** + * clear current node(server or task) all data in memery and local disk + * @return + */ + def clearAll(): Boolean + + /** + * check whether the parameter server storage contains specific key/value + * @param k: key + * @tparam K: type of key + * @return + */ + def exists[K: ClassTag](k: K): Boolean + + /** + * iterator all data + * @tparam K: type of Key + * @tparam V: type of Value + * @return + */ + def toIterator[K: ClassTag, V: ClassTag](): Iterator[(K, V)] + + /** + * apply `agg` and `func` to deltas and original values + * @param agg: functions applied to deltas + * @param func: functions allied to original values with specific deltas + * @tparam K: type of Key + * @tparam V: type of Value + */ + def applyDelta[K: ClassTag, V: ClassTag](agg: ArrayBuffer[V] => V, func: (V, V) => V): Unit +} + +private[ps] object PSStorage { + + private val configKey = "spark.ps.kv.storage" + private val FALLBACK_KV_STORAGE = "RocksDB" + + private val psKVStorageNames = Map( + "MemoryStorage" -> classOf[MemoryStorage].getName) + + def getKVStorage(sparkConf: SparkConf): PSStorage = { + val stName = sparkConf.get(configKey, "MemoryStorage") + val stClass = psKVStorageNames.getOrElse(stName, stName) + val st = try { + val ctor = Class.forName(stClass, true, Utils.getContextOrSparkClassLoader) + .getConstructor(classOf[SparkConf]) + Some(ctor.newInstance(sparkConf).asInstanceOf[PSStorage]) + } catch { + case e: ClassNotFoundException => None + case e: IllegalArgumentException => None + } + st.getOrElse(throw new IllegalArgumentException( + s"Key/Value Storage [$stClass] is not available. " + + s"Consider setting $configKey=$FALLBACK_KV_STORAGE")) + } +} diff --git a/core/src/main/scala/org/apache/spark/ps/strategy/ConsistencyStrategy.scala b/core/src/main/scala/org/apache/spark/ps/strategy/ConsistencyStrategy.scala new file mode 100644 index 000000000000..b08716136df3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/strategy/ConsistencyStrategy.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps.strategy + +/** + * ConsistencyStrategy + * Created by genmao.ygm on 15-3-19. + */ +private[ps] trait ConsistencyStrategy { + def doConsistent(clock: Int): Unit +} diff --git a/core/src/main/scala/org/apache/spark/ps/strategy/ModelPartitionStrategy.scala b/core/src/main/scala/org/apache/spark/ps/strategy/ModelPartitionStrategy.scala new file mode 100644 index 000000000000..f06be7e16879 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/strategy/ModelPartitionStrategy.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps.strategy + +/** + * ModelPartitionStrategy + * Created by genmao.ygm on 15-3-19. + */ +private[ps] class ModelPartitionStrategy( + numRowPartitions: Int = 1, + numColPartitions: Int = 1, + numRows: Int = 1, + numCols: Int = 1) { + + def getLocInServers(row: Int, col: Int): (Int, Int, Int) = { + val rIdx = getIdxAtRange(row, numRows, numRowPartitions) + val cIdx = getIdxAtRange(col, numCols, numColPartitions) + val serverId = (rIdx - 1) * numRowPartitions + cIdx + (serverId, row, col) + } + + def getRowLocInServers(row: Int, col: Int): (Int, Int) = { + val rIdx = getIdxAtRange(row, numRows, numRowPartitions) + val cIdx = getIdxAtRange(col, numCols, numColPartitions) + val serverId = (rIdx - 1) * numRowPartitions + cIdx + (serverId, row) + } + + private def getIdxAtRange(n: Int, len: Int, range: Int): Int = n * range / len + 1 + +} diff --git a/core/src/main/scala/org/apache/spark/ps/util/Utils.scala b/core/src/main/scala/org/apache/spark/ps/util/Utils.scala new file mode 100644 index 000000000000..dd3716689cee --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ps/util/Utils.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ps.util + +import java.io.{IOException, File} + +import org.apache.spark.util.Utils._ +import org.apache.spark.{SparkConf, Logging} + +/** + * some common interface + */ +private[ps] object Utils extends Logging { + + /** Get parameter server local directories **/ + def getOrCreatePSLocalDirs(conf: SparkConf):Array[String] = { + conf.get("SPARK_PS_DIRS", conf.getOption("spark.ps.dir") + .getOrElse(System.getProperty("java.io.tmpdir"))) + .split(",") + .flatMap { dir => + try { + val file = new File(dir) + if (file.exists || file.mkdirs()) { + val tmpDir: File = createDirectory(dir) + chmod700(tmpDir) + Some(tmpDir.getAbsolutePath) + } else { + logError(s"Failed to create ps dir in $dir. Ignoring this directory.") + None + } + } catch { + case e: IOException => + logError(s"Failed to create ps root dir in $dir. Ignoring this directory.") + None + } + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 3ab9e54f0ec5..b0d89bd787b9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -37,6 +37,7 @@ import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult +import org.apache.spark.ps.PSClient import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap @@ -287,6 +288,18 @@ abstract class RDD[T: ClassTag]( new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF)) } + def runWithPS[U: ClassTag]( + tableId: Int, + func: (Array[T], PSClient) => U): Array[U] = { + val f: (TaskContext, Iterator[T]) => U = (taskContext: TaskContext, iter: Iterator[T]) => { + val client = taskContext.getPSClient + val arr = iter.toArray + client.initClock(0) + func(arr, client) + } + sc.runJobWithPS(this, f) + } + /** * Return a new RDD by first applying a function to all elements of this * RDD, and then flattening the results. 
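Not part of the patch, but for context: a driver program would combine the pieces introduced above (PSContext, RDD.runWithPS, and the PSClient handed to each task) roughly as in the sketch below. Only the API calls (loadPSModel, runWithPS, get/update/clock, downloadPSModel) come from this diff; the table name "w", the model dimension, the iteration count, and the placeholder delta computation are illustrative assumptions.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ps.PSContext

object PSUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("ps-sketch"))
    // Blocks on sc.startLatch until all requested parameter servers have registered.
    val ps = new PSContext(sc)

    val dim = 8                                            // hypothetical model size
    ps.loadPSModel(sc.parallelize(Seq(("w", Array.fill(dim)(0.0))), 1))

    val data = sc.parallelize(0 until 1000, 4).map(i => Array.fill(dim)(i.toDouble))
    data.runWithPS[Unit](1, (points, client) => {
      for (_ <- 0 until 5) {
        val w = client.get("w")                            // pull the current parameters
        // Placeholder delta; a real job would compute a gradient from `points` and `w`.
        val delta = w.map(_ => 0.01 * points.length)
        client.update("w", delta)                          // push the local delta
        client.clock()                                     // advance the clock; may block
      }
    })

    println(ps.downloadPSModel(Array("w"), 1).head.mkString(", "))
    sc.stop()
  }
}
```

With the staleness bound `spark.ps.stale.clock` left at its default of 0, the `clock()` call makes tasks proceed in lock step; a larger value lets a task run ahead of the global clock by that many iterations before it is blocked.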
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 847a4912eec1..abb836e17ec9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.ps.PSClient import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -51,9 +52,9 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * @param attemptNumber how many times this task has been attempted (0 for the first attempt) * @return the result of the task */ - final def run(taskAttemptId: Long, attemptNumber: Int): T = { - context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, - taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) + final def run(psClient: Option[PSClient], taskAttemptId: Long, attemptNumber: Int): T = { + context = new TaskContextImpl( + stageId, partitionId, taskAttemptId, attemptNumber, psClient, false) TaskContextHelper.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 9bf74f4be198..72bb99074c33 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -19,6 +19,9 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer +import scala.Predef._ + +import org.apache.spark.ps.ServerInfo import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{SerializableBuffer, Utils} @@ -85,4 +88,16 @@ private[spark] object CoarseGrainedClusterMessages { case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage + case class RegisterServer( + executorId: String, + hostPort: String, + cores: Int, + containerId: String) extends CoarseGrainedClusterMessage { + Utils.checkHostPort(hostPort, "Expected host port") + } + + case class RegisteredServer(serverId: Long) extends CoarseGrainedClusterMessage + + case class AddNewPSServer(serverInfo: ServerInfo) extends CoarseGrainedClusterMessage + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 87ebf31139ce..e779f7fea3ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -28,6 +28,8 @@ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} +import org.apache.spark.ps.CoarseGrainedParameterServerMessage.NotifyServer +import org.apache.spark.ps.{ServerInfo, PSServerManager} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import 
org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} @@ -68,6 +70,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste private val listenerBus = scheduler.sc.listenerBus + private val psServerManager = new PSServerManager + // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] @@ -97,8 +101,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls) + val (host, port) = Utils.parseHostPort(hostPort) + val executorUrl = AkkaUtils.address( + AkkaUtils.protocol(), + SparkEnv.executorActorSystemName, + host, + port, + "Executor") + val data = new ExecutorData(sender, sender.path.address, host, cores, + cores, logUrls, executorUrl) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -110,9 +121,28 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } listenerBus.post( SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) + notifyExecutorAndServer(sender, "executor", executorUrl = executorUrl) makeOffers() } + case RegisterServer(executorId, hostPort, cores, containerId) => + logInfo(s"Registering PS server $executorId at $hostPort") + val (host, port) = Utils.parseHostPort(hostPort) + CoarseGrainedSchedulerBackend.this.synchronized { + val serverId = psServerManager.newServerId() + val serverUrl = AkkaUtils.address( + AkkaUtils.protocol(), + SparkEnv.executorActorSystemName, + host, + port, + "Server") + val serverInfo = new ServerInfo(serverId, serverUrl, sender.path.address, host, cores) + psServerManager.addPSServer(executorId, hostPort, containerId, serverInfo) + sender ! RegisteredServer(serverId) + notifyExecutorAndServer(sender, "server", serverInfo = serverInfo) + checkPSInitialized() + } + case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { @@ -162,6 +192,35 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste sender ! sparkProperties } + /** + * Notify all existing executors that a new PS server has registered, or notify all + * existing PS servers that a new executor has registered, depending on `role`. + * @param sender the actor that just registered (an executor or a PS server) + * @param serverInfo information about the newly registered server (used when role is "server") + */ + def notifyExecutorAndServer( + sender: ActorRef, + role: String, + serverInfo: ServerInfo = null, + executorUrl: String = null): Unit = { + synchronized { + role match { + case "executor" => + psServerManager.getAllServers.foreach(e => { + val serverRef = context.actorSelection(e._2.serverUrl) + serverRef ! NotifyServer(executorUrl) + sender ! AddNewPSServer(e._2) + }) + case "server" => + executorDataMap.foreach(e => { + val executorRef = e._2.executorActor + executorRef ! AddNewPSServer(serverInfo) + sender ! 
NotifyServer(e._2.executorUrl) + }) + } + } + } + // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => @@ -266,6 +325,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } + def checkPSInitialized(): Unit = { + val numServers = conf.get("spark.num.servers").toInt + assert(numServers > 0) + val currentServers = psServerManager.getAllServers.size + if (currentServers == numServers) { + scheduler.sc.startLatch.countDown() + } + } + override def reviveOffers() { driverActor ! ReviveOffers } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 5e571efe7672..047233348356 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -34,5 +34,6 @@ private[cluster] class ExecutorData( override val executorHost: String, var freeCores: Int, override val totalCores: Int, - override val logUrlMap: Map[String, String] + override val logUrlMap: Map[String, String], + val executorUrl: String ) extends ExecutorInfo(executorHost, totalCores, logUrlMap) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 4676b828d3d8..a4f8cebf4d47 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -26,6 +26,7 @@ import akka.actor.{Actor, ActorRef, Props} import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.ps.PSClient import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.util.ActorLogReceive @@ -127,6 +128,8 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: localActor ! StatusUpdate(taskId, state, serializedData) } + override def getPSClient: Option[PSClient] = None + override def applicationId(): String = appId } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index dc9f77360d67..d6dcac0dafa8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -987,7 +987,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 4b25c200a695..de16aec4d559 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -85,7 +85,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // Local computation should not persist the resulting value, so don't expect a put(). 
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, 0, None, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 057e22691602..34883ea195fe 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -45,7 +45,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte val task = new ResultTask[String, String]( 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { - task.run(0, 0) + task.run(None, 0, 0) } assert(TaskContextSuite.completed === true) } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PSLogisticRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PSLogisticRegressionExample.scala new file mode 100644 index 000000000000..a0b3d7494972 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PSLogisticRegressionExample.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.classification.PSLogisticRegression +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.{SparkContext, SparkConf} + + +object PSLogisticRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName(s"PSLogisticRegressionExample") + val sc = new SparkContext(conf) + + val input = args(0) + val numIterations = args(1).toInt + val stepSize = args(2).toDouble + val miniBatchFraction = args(3).toDouble + val partition = args(4).toInt + + val (training, test) = if (args.length == 6) { + val testFile = args(5) + + val training = MLUtils.loadLibSVMFile(sc, input, 123, partition).map { p => + val label = if (p.label == 1.0) 1.0 else 0.0 + LabeledPoint(label, p.features) + }.cache() + val test = MLUtils.loadLibSVMFile(sc, testFile, 123).map { p => + val label = if (p.label == 1.0) 1.0 else 0.0 + LabeledPoint(label, p.features) + }.cache() + + val numTraining = training.count() + val numTest = test.count() + println(s"Training: $numTraining, test: $numTest.") + + (training, test) + } else { + val examples = MLUtils.loadLibSVMFile(sc, input, -1, partition).map { p => + val label = if (p.label == 1.0) 1.0 else 0.0 + LabeledPoint(label, p.features) + }.cache() + val splits = examples.randomSplit(Array(0.8, 0.2)) + val training = splits(0).cache() + val test = splits(1).cache() + + val numTraining = training.count() + val numTest = test.count() + println(s"Training: $numTraining, test: $numTest.") + + examples.unpersist(blocking = false) + + (training, test) + } + + val numFeatureTraining = training.take(1).head.features.size + val numFeatureTest = test.take(1).head.features.size + println(s"feature in training: $numFeatureTraining, feature in test: $numFeatureTest") + + val model = PSLogisticRegression.train(sc, training, numIterations, stepSize, miniBatchFraction) + + + val prediction = model.predict(test.map(_.features)) + val predictionAndLabel = prediction.zip(test.map(_.label)) + + val metrics = new BinaryClassificationMetrics(predictionAndLabel) + + println(s"Test areaUnderPR = ${metrics.areaUnderPR()}.") + println(s"Test areaUnderROC = ${metrics.areaUnderROC()}.") + + sc.stop() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/PSLogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/PSLogisticRegression.scala new file mode 100644 index 000000000000..74b5ccd21610 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/PSLogisticRegression.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + + +import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.ps.PSContext +import org.apache.spark.util.random.BernoulliSampler +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD + +object PSLogisticRegression { + + def train( + sc: SparkContext, + input: RDD[LabeledPoint], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double): LogisticRegressionModel = { + val numFeatures = input.map(_.features.size).first() + val initialWeights = new Array[Double](numFeatures) + + + val pssc = new PSContext(sc) + val initialModelRDD = sc.parallelize(Array(("w", initialWeights)), 1) + pssc.loadPSModel(initialModelRDD) + + input.runWithPS(2, (arr, client) => { + val sampler = new BernoulliSampler[LabeledPoint](miniBatchFraction) + for (i <- 0 to numIterations) { + val weights = Vectors.dense(client.get("w")) + + sampler.setSeed(i + 42) + sampler.sample(arr.toIterator).foreach { point => + val data = point.features + val label = point.label + val margin = -1.0 * dot(data, weights) + val multiplier = (1.0 / (1.0 + math.exp(margin))) - label + val delta = Vectors.dense(new Array[Double](numFeatures)) + axpy((-1) * stepSize / math.sqrt(i + 1) * multiplier, data, delta) + client.update("w", delta.toArray) + } + + client.clock() + } + }) + + val weights = Vectors.dense(pssc.downloadPSModel(Array("w"), 1)(0)) + val intercept = 0.0 + + new LogisticRegressionModel(weights, intercept).clearThreshold() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 8a36c6810790..db5c345d3f29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -49,7 +49,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi }.foreach(println) */ - override def eval(input: Row): Any = { val result = children.size match { case 0 => function.asInstanceOf[() => Any]() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 418f9048d656..5829b4c49047 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -90,6 +90,12 @@ private[spark] class ApplicationMaster( // Set the master property to match the requested mode. System.setProperty("spark.master", "yarn-cluster") + // Set if need to run a job on `Parameter Server`. + System.setProperty("spark.enablePS", args.enablePS.toString) + + // Set the number of servers. + System.setProperty("spark.num.servers", args.numPSServers.toString) + // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. 
System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) } @@ -224,7 +230,9 @@ private[spark] class ApplicationMaster( .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } .getOrElse("") - allocator = client.register(yarnConf, + val enablePS = args.enablePS + allocator = client.register(enablePS, + yarnConf, if (sc != null) sc.getConf else sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index e1a992af3aae..2c2df8a75798 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -27,9 +27,13 @@ class ApplicationMasterArguments(val args: Array[String]) { var primaryPyFile: String = null var pyFiles: String = null var userArgs: Seq[String] = Seq[String]() + var enablePS: Boolean = false var executorMemory = 1024 + var psServerMemory = 1024 var executorCores = 1 + var psServerCores = 1 var numExecutors = DEFAULT_NUMBER_EXECUTORS + var numPSServers = DEFAULT_NUMBER_PS_SERVERS parseArgs(args.toList) @@ -62,6 +66,10 @@ class ApplicationMasterArguments(val args: Array[String]) { userArgsBuffer += value args = tail + case ("--enablePS") :: value :: tail => + enablePS = value.toBoolean + args = tail + case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => numExecutors = value args = tail @@ -74,6 +82,18 @@ class ApplicationMasterArguments(val args: Array[String]) { executorCores = value args = tail + case ("--num-servers") :: IntParam(value) :: tail => + numPSServers = value + args = tail + + case ("--server-memory") :: MemoryParam(value) :: tail => + psServerMemory = value + args = tail + + case ("--server-cores") :: IntParam(value) :: tail => + psServerCores = value + args = tail + case _ => printUsageAndExit(1, args) } @@ -96,9 +116,13 @@ class ApplicationMasterArguments(val args: Array[String]) { | place on the PYTHONPATH for Python apps. | --args ARGS Arguments to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. + | --enablePS enable parameter server | --num-executors NUM Number of executors to start (Default: 2) | --executor-cores NUM Number of cores for the executors (Default: 1) | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) + | --num-servers NUM Number of servers to start (Default: 1) + | --server-cores NUM Number of cores for the server (Default: 1) + | --server-memory MEM Memory per server (e.g. 
1000M, 2G) (Default: 1G) """.stripMargin) System.exit(exitCode) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 61f8fc3f5a01..ff9b6cd6386c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -505,7 +505,11 @@ private[spark] class Client( Seq( "--executor-memory", args.executorMemory.toString + "m", "--executor-cores", args.executorCores.toString, - "--num-executors ", args.numExecutors.toString) + "--num-executors ", args.numExecutors.toString, + "--server-memory", args.psServerMemory + "m", + "--server-cores", args.psServerCores.toString, + "--num-servers", args.numPSservers.toString, + "--enablePS", args.enablePS.toString) // Command for the ApplicationMaster val commands = prefixEnv ++ Seq( diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index c1d3f7320f53..653906246866 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -42,6 +42,7 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.network.util.JavaUtils class ExecutorRunnable( + isPSServer: Boolean, container: Container, conf: Configuration, sparkConf: SparkConf, @@ -202,6 +203,12 @@ class ExecutorRunnable( Seq("--user-class-path", "file:" + absPath) }.toSeq + val workerBackend = if (isPSServer) { + "org.apache.spark.ps.CoarseGrainedParameterServerBackend" + } else { + "org.apache.spark.executor.CoarseGrainedExecutorBackend" + } + val commands = prefixEnv ++ Seq( YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server", @@ -212,7 +219,7 @@ class ExecutorRunnable( // 'something' to fail job ... akin to blacklisting trackers in mapred ? "-XX:OnOutOfMemoryError='kill %p'") ++ javaOpts ++ - Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", + Seq(workerBackend, "--driver-url", masterAddress.toString, "--executor-id", slaveId.toString, "--hostname", hostname.toString, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index c98763e15b58..ffe691d1034c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -53,6 +53,7 @@ import org.apache.spark.util.AkkaUtils * synchronized. */ private[yarn] class YarnAllocator( + enablePS: Boolean, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -71,6 +72,7 @@ private[yarn] class YarnAllocator( // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] val allocatedContainerToHostMap = new HashMap[ContainerId, String] + val allocatedContainerForPSServer = new HashMap[ContainerId, String] // Containers that we no longer care about. We've either already told the RM to release them or // will on the next heartbeat. 
Containers get removed from this map after the RM tells us they've @@ -79,11 +81,13 @@ private[yarn] class YarnAllocator( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) @volatile private var numExecutorsRunning = 0 + @volatile private var numPSServersRunning = 0 // Used to generate a unique ID per executor private var executorIdCounter = 0 @volatile private var numExecutorsFailed = 0 @volatile private var targetNumExecutors = args.numExecutors + @volatile private var targetNumServers = args.numPSServers // Keep track of which container is running which executor to remove the executors later // Visible for testing. @@ -91,13 +95,19 @@ private[yarn] class YarnAllocator( // Executor memory in MB. protected val executorMemory = args.executorMemory + // PS Server memory in MB. + protected val psServerMemory = args.psServerMemory // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)) // Number of cores per executor. protected val executorCores = args.executorCores + // Number of cores per server. + protected val psServerCores = args.psServerCores // Resource capability requested for each executors - private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private val resource4Executor = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + // Resource capability requested for each server + private val resource4PSServer = Resource.newInstance(psServerMemory + memoryOverhead, psServerCores) private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue @@ -121,15 +131,17 @@ private[yarn] class YarnAllocator( def getNumExecutorsFailed: Int = numExecutorsFailed + def getNumPSServersRunning: Int = numPSServersRunning + /** * Number of container requests that have not yet been fulfilled. */ - def getNumPendingAllocate: Int = getNumPendingAtLocation(ANY_HOST) + def getNumPendingAllocate: Int = getNumPendingAtLocation(ANY_HOST, resource4Executor) /** * Number of container requests at the given location that have not yet been fulfilled. */ - private def getNumPendingAtLocation(location: String): Int = + private def getNumPendingAtLocation(location: String, resource: Resource): Int = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, location, resource).map(_.size).sum /** @@ -151,7 +163,15 @@ private[yarn] class YarnAllocator( if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get internalReleaseContainer(container) - numExecutorsRunning -= 1 + if (allocatedContainerForPSServer.contains(container.getId)) { + numPSServersRunning -= 1 + targetNumServers -= 1 + assert(targetNumServers >= 0, "Allocator killed more servers than are allocated!") + } else { + numExecutorsRunning -= 1 + targetNumExecutors -= 1 + assert(targetNumExecutors >= 0, "Allocator killed more executors than are allocated!") + } } else { logWarning(s"Attempted to kill unknown executor $executorId!") } @@ -166,7 +186,15 @@ private[yarn] class YarnAllocator( * This must be synchronized because variables read in this method are mutated by other methods. 
*/ def allocateResources(): Unit = synchronized { - updateResourceRequests() + val numPendingAllocate = getNumPendingAllocate + val missing = + if (enablePS) { + targetNumExecutors + targetNumServers - numPendingAllocate - numExecutorsRunning - numPSServersRunning + } else { + targetNumExecutors - numPendingAllocate - numExecutorsRunning + } + + updateResourceRequests(resource4Executor, missing, numPendingAllocate) val progressIndicator = 0.1f // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container @@ -182,7 +210,40 @@ private[yarn] class YarnAllocator( numExecutorsRunning, allocateResponse.getAvailableResources)) - handleAllocatedContainers(allocatedContainers) + handleAllocatedContainers(allocatedContainers, resource4Executor, executorMemory, executorCores) + } + + val completedContainers = allocateResponse.getCompletedContainersStatuses() + if (completedContainers.size > 0) { + logDebug("Completed %d containers".format(completedContainers.size)) + + processCompletedContainers(completedContainers) + + logDebug("Finished processing %d completed containers. Current running executor count: %d." + .format(completedContainers.size, numExecutorsRunning)) + } + } + + def allocateServerResources(): Unit = synchronized { + val numPendingAllocate = getNumPendingAtLocation(ANY_HOST, resource4PSServer) + val missing = targetNumServers - numPendingAllocate - numPSServersRunning + updateResourceRequests(resource4PSServer, missing, numPendingAllocate) + + val progressIndicator = 0.1f + // Poll the ResourceManager. This doubles as a heartbeat if there are no pending container + // requests. + val allocateResponse = amClient.allocate(progressIndicator) + + val allocatedContainers = allocateResponse.getAllocatedContainers() + + if (allocatedContainers.size > 0) { + logDebug("Allocated containers: %d. Current executor count: %d. Cluster resources: %s." + .format( + allocatedContainers.size, + numPSServersRunning, + allocateResponse.getAvailableResources)) + + handleAllocatedContainers(allocatedContainers, resource4PSServer, psServerMemory, psServerCores) } val completedContainers = allocateResponse.getCompletedContainersStatuses() @@ -202,7 +263,7 @@ private[yarn] class YarnAllocator( * * Visible for testing. */ - def updateResourceRequests(): Unit = { + def updateResourceRequests(resource: Resource, missing: Int, numPendingAllocate: Int): Unit = { val numPendingAllocate = getNumPendingAllocate val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning @@ -238,13 +299,17 @@ private[yarn] class YarnAllocator( * * Visible for testing. 
*/ - def handleAllocatedContainers(allocatedContainers: Seq[Container]): Unit = { + def handleAllocatedContainers( + allocatedContainers: Seq[Container], + resource: Resource, + executorMemory: Int, + executorCores: Int ): Unit = { val containersToUse = new ArrayBuffer[Container](allocatedContainers.size) // Match incoming requests by host val remainingAfterHostMatches = new ArrayBuffer[Container] for (allocatedContainer <- allocatedContainers) { - matchContainerToRequest(allocatedContainer, allocatedContainer.getNodeId.getHost, + matchContainerToRequest(resource, allocatedContainer, allocatedContainer.getNodeId.getHost, containersToUse, remainingAfterHostMatches) } @@ -252,14 +317,14 @@ private[yarn] class YarnAllocator( val remainingAfterRackMatches = new ArrayBuffer[Container] for (allocatedContainer <- remainingAfterHostMatches) { val rack = RackResolver.resolve(conf, allocatedContainer.getNodeId.getHost).getNetworkLocation - matchContainerToRequest(allocatedContainer, rack, containersToUse, + matchContainerToRequest(resource, allocatedContainer, rack, containersToUse, remainingAfterRackMatches) } // Assign remaining that are neither node-local nor rack-local val remainingAfterOffRackMatches = new ArrayBuffer[Container] for (allocatedContainer <- remainingAfterRackMatches) { - matchContainerToRequest(allocatedContainer, ANY_HOST, containersToUse, + matchContainerToRequest(resource, allocatedContainer, ANY_HOST, containersToUse, remainingAfterOffRackMatches) } @@ -271,7 +336,7 @@ private[yarn] class YarnAllocator( } } - runAllocatedContainers(containersToUse) + runAllocatedContainers(containersToUse, resource, executorMemory, executorCores) logInfo("Received %d containers from YARN, launching executors on %d of them." .format(allocatedContainers.size, containersToUse.size)) @@ -288,6 +353,7 @@ private[yarn] class YarnAllocator( * @param remaining list of containers that will not be used */ private def matchContainerToRequest( + resource: Resource, allocatedContainer: Container, location: String, containersToUse: ArrayBuffer[Container], @@ -314,10 +380,20 @@ private[yarn] class YarnAllocator( /** * Launches executors in the allocated containers. 
*/ - private def runAllocatedContainers(containersToUse: ArrayBuffer[Container]): Unit = { + private def runAllocatedContainers( + containersToUse: ArrayBuffer[Container], + resource: Resource, + executorMemory: Int, + executorCores: Int ): Unit = { for (container <- containersToUse) { - numExecutorsRunning += 1 - assert(numExecutorsRunning <= targetNumExecutors) + val isPSServer = + if (enablePS && numPSServersRunning < targetNumServers) { + numPSServersRunning += 1 + true + } else { + numExecutorsRunning += 1 + false + } val executorHostname = container.getNodeId.getHost val containerId = container.getId executorIdCounter += 1 @@ -334,21 +410,27 @@ private[yarn] class YarnAllocator( containerSet += containerId allocatedContainerToHostMap.put(containerId, executorHostname) - val executorRunnable = new ExecutorRunnable( - container, - conf, - sparkConf, - driverUrl, - executorId, - executorHostname, - executorMemory, - executorCores, - appAttemptId.getApplicationId.toString, - securityMgr) + if (isPSServer) { + allocatedContainerForPSServer.put(containerId, executorHostname) + } + + val runnable = + new ExecutorRunnable( + isPSServer, + container, + conf, + sparkConf, + driverUrl, + executorId, + executorHostname, + executorMemory, + executorCores, + appAttemptId.getApplicationId.toString, + securityMgr) if (launchContainers) { logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( driverUrl, executorHostname)) - launcherPool.execute(executorRunnable) + launcherPool.execute(runnable) } } } @@ -365,7 +447,12 @@ private[yarn] class YarnAllocator( } else { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning -= 1 + if (allocatedContainerForPSServer.contains(containerId)) { + allocatedContainerForPSServer.remove(containerId) + numPSServersRunning -= 1 + } else { + numExecutorsRunning -= 1 + } logInfo("Completed container %s (state: %s, exit status: %s)".format( containerId, completedContainer.getState, diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index b13475136652..75227105def6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -55,6 +55,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg * @param uiHistoryAddress Address of the application on the History Server. 
*/ def register( + enablePS: Boolean, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -72,7 +73,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(enablePS, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 146b2c0f1a30..ffcf81bc5885 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -95,6 +95,7 @@ object YarnSparkHadoopUtil { val ANY_HOST = "*" val DEFAULT_NUMBER_EXECUTORS = 2 + val DEFAULT_NUMBER_PS_SERVERS = 1 // All RM requests are issued with same priority : we do not (yet) have any distinction between // request types (like map/reduce in hadoop for example)
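Editor's note on the end-to-end flow: `ApplicationMasterArguments` and `Client` forward `--enablePS`, `--num-servers`, `--server-memory`, and `--server-cores` to the ApplicationMaster, which sets `spark.enablePS` and `spark.num.servers`; `YarnAllocator` then launches the extra containers with `org.apache.spark.ps.CoarseGrainedParameterServerBackend` as the main class, and `CoarseGrainedSchedulerBackend.checkPSInitialized` counts down the SparkContext's `startLatch` once `spark.num.servers` servers have registered. The actual `PSClient`/`PSContext` implementations are not part of this diff; the trait below is only a hedged sketch of the client surface that the call sites in this patch (`RDD.runWithPS`, `PSLogisticRegression`) appear to assume, with signatures inferred rather than taken from committed code.

package org.apache.spark.ps

/**
 * Hypothetical sketch of the client API implied by this patch's call sites.
 * Method names match the usages in RDD.runWithPS and PSLogisticRegression;
 * the signatures and semantics are assumptions, not the real implementation.
 */
trait PSClientSketch {
  /** Set this task's initial clock value; RDD.runWithPS calls initClock(0) before the user function. */
  def initClock(clock: Int): Unit

  /** Fetch the current value of a named row, e.g. the weight row "w". */
  def get(rowName: String): Array[Double]

  /** Push a delta that the servers fold into the named row. */
  def update(rowName: String, delta: Array[Double]): Unit

  /** Advance this task's clock, typically once per outer iteration. */
  def clock(): Unit
}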