34 changes: 34 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,6 +24,7 @@ import java.lang.reflect.Constructor
import java.net.URI
import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.CountDownLatch
import java.util.UUID.randomUUID

import scala.collection.{Map, Set}
@@ -196,6 +197,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private[spark] val conf = config.clone()
conf.validateSettings()

/** Latch that is released once the parameter server has been marked as started. */
val startLatch = new CountDownLatch(1)

/**
* Return a copy of this SparkContext's configuration. The configuration ''cannot'' be
* changed at runtime.
@@ -1441,6 +1445,36 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}.getOrElse(Utils.getCallSite())
}

/**
* Run a function on every partition of an RDD as a parameter server job and return the
* results as an array, one element per partition. Unlike runJob, this entry point takes no
* custom result handler and always ships the computation to the cluster (allowLocal = false).
*/
def runJobWithPS[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U): Array[U] = {
if (stopped) {
throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting run parameter server job: " + callSite.shortForm)
if (conf.getBoolean("spark.logLineage", false)) {
logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
}
val results = new Array[U](rdd.partitions.size)
val resultHandler: (Int, U) => Unit = (pid, value) => {
logInfo(s"partition number $pid, value: $value")
results(pid) = value
}
dagScheduler.runJob(rdd, cleanedFunc, 0 until rdd.partitions.size,
callSite, false, resultHandler, localProperties.get)
progressBar.foreach(_.finishAll())
rdd.doCheckpoint()
results
}

/**
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
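For context, a rough usage sketch of the runJobWithPS entry point added above; the SparkContext sc, the sample data, and the task body are illustrative assumptions, not part of this patch:

// Sketch only: assumes a SparkContext `sc` built from a branch carrying this patch.
val data = sc.parallelize(1 to 100, 4)
val partitionSizes = sc.runJobWithPS(data, (ctx: TaskContext, it: Iterator[Int]) => {
  // ctx.getPSClient (added to TaskContext below) would expose the parameter server client;
  // this sketch only counts the elements seen in each partition.
  it.size
})
println(partitionSizes.mkString(", "))  // one entry per partition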
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,6 +21,7 @@ import java.io.Serializable

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ps.PSClient
import org.apache.spark.util.TaskCompletionListener


@@ -133,4 +134,6 @@ abstract class TaskContext extends Serializable {
/** ::DeveloperApi:: */
@DeveloperApi
def taskMetrics(): TaskMetrics

def getPSClient: PSClient
}
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ps.PSClient
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

import scala.collection.mutable.ArrayBuffer
@@ -27,6 +28,7 @@ private[spark] class TaskContextImpl(
val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
val psClient: Option[PSClient] = None,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
@@ -92,5 +94,7 @@ private[spark] class TaskContextImpl(
override def isRunningLocally(): Boolean = runningLocally

override def isInterrupted(): Boolean = interrupted

override def getPSClient: PSClient = psClient.get
}
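Note that getPSClient above delegates to psClient.get, so a task that asks for the client when no parameter server is attached fails with a bare NoSuchElementException. A possible variant with a clearer error message (the wording is ours, not the patch's):

override def getPSClient: PSClient =
  psClient.getOrElse(throw new IllegalStateException(
    "No PSClient is attached to this task; was the application submitted with --enablePS?"))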

21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -41,6 +41,10 @@ private[spark] class ClientArguments(args: Array[String]) {
var supervise: Boolean = DEFAULT_SUPERVISE
var memory: Int = DEFAULT_MEMORY
var cores: Int = DEFAULT_CORES
var psServerMemory = DEFAULT_MEMORY
var psServerCores = DEFAULT_CORES
var numPSservers = DEFAULT_NUMBER_PS_SERVERS
var enablePS: Boolean = false
private var _driverOptions = ListBuffer[String]()
def driverOptions = _driverOptions.toSeq

@@ -58,6 +62,22 @@ private[spark] class ClientArguments(args: Array[String]) {
memory = value
parse(tail)

case ("--num-servers") :: IntParam(value) :: tail =>
numPSservers = value
parse(tail)

case ("--server-memory") :: MemoryParam(value) :: tail =>
psServerMemory = value
parse(tail)

case ("--server-cores") :: IntParam(value) :: tail =>
psServerCores = value
parse(tail)

case ("--enablePS") :: value :: tail =>
enablePS = value.toBoolean
parse(tail)

case ("--supervise" | "-s") :: tail =>
supervise = true
parse(tail)
@@ -119,6 +139,7 @@ private[spark] class ClientArguments(args: Array[String]) {
object ClientArguments {
private[spark] val DEFAULT_CORES = 1
private[spark] val DEFAULT_MEMORY = 512 // MB
private[spark] val DEFAULT_NUMBER_PS_SERVERS = 1
private[spark] val DEFAULT_SUPERVISE = false

def isValidJarUrl(s: String): Boolean = {
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -367,8 +367,12 @@ object SparkSubmit {
OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"),
OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"),
OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"),
OptionAssigner(args.numPSServers, YARN, CLUSTER, clOption = "--num-servers"),
OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"),
OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"),
OptionAssigner(args.psServerMemory, YARN, CLUSTER, clOption = "--server-memory"),
OptionAssigner(args.psServerCores, YARN, CLUSTER, clOption = "--server-cores"),
OptionAssigner(args.enablePS.toString, YARN, CLUSTER, clOption = "--enablePS"),
OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"),
OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"),
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -34,6 +34,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var deployMode: String = null
var executorMemory: String = null
var executorCores: String = null
var psServerMemory: String = null
var psServerCores: String = null
var totalExecutorCores: String = null
var propertiesFile: String = null
var driverMemory: String = null
@@ -42,6 +44,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var driverExtraJavaOptions: String = null
var queue: String = null
var numExecutors: String = null
var numPSServers: String = null
var files: String = null
var archives: String = null
var mainClass: String = null
@@ -52,6 +55,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var packages: String = null
var repositories: String = null
var ivyRepoPath: String = null
var enablePS: Boolean = false
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
@@ -315,6 +319,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
numExecutors = value
parse(tail)

case ("--num-servers") :: value :: tail =>
numPSServers = value
parse(tail)

case ("--total-executor-cores") :: value :: tail =>
totalExecutorCores = value
parse(tail)
@@ -327,6 +335,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
executorMemory = value
parse(tail)

case ("--server-cores") :: value :: tail =>
psServerCores = value
parse(tail)

case ("--server-memory") :: value :: tail =>
psServerMemory = value
parse(tail)

case ("--driver-memory") :: value :: tail =>
driverMemory = value
parse(tail)
@@ -417,6 +433,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
verbose = true
parse(tail)

case ("--enablePS") :: tail =>
enablePS = true
parse(tail)

case ("--version") :: tail =>
SparkSubmit.printVersionAndExit()

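Taken together, the new submit-side options (--enablePS, --num-servers, --server-memory, --server-cores) mirror the existing executor options. A minimal sketch of how they would be parsed; SparkSubmitArguments is private[spark], so this would only compile inside Spark's own org.apache.spark.deploy package (for example in a test), and every value below is illustrative:

// Sketch only: the flag names come from the parser above, everything else is made up.
val submitArgs = Seq(
  "--master", "yarn-cluster",
  "--class", "com.example.PSApp",   // hypothetical application class
  "--enablePS",                     // request a parameter server
  "--num-servers", "2",             // -> numPSServers
  "--server-memory", "1g",          // -> psServerMemory
  "--server-cores", "1",            // -> psServerCores
  "/path/to/app.jar")               // hypothetical application jar
val parsed = new SparkSubmitArguments(submitArgs, sys.env)
assert(parsed.enablePS && parsed.numPSServers == "2")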
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -17,8 +17,10 @@

package org.apache.spark.executor

import java.io._
import java.net.URL
import java.nio.ByteBuffer
import java.lang.management.ManagementFactory

import scala.collection.mutable
import scala.concurrent.Await
@@ -31,9 +33,15 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.ps.CoarseGrainedParameterServerMessage.NotifyClient
import org.apache.spark.ps.{PSClient, ServerInfo}
import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.KillTask
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutorFailed
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.LaunchTask
import org.apache.spark.util._

private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
@@ -48,6 +56,8 @@

var executor: Executor = null
var driver: ActorSelection = null
val psServers = new mutable.HashMap[Long, ServerInfo]()
var psClient: Option[PSClient] = None

override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
@@ -72,6 +82,22 @@
logError("Slave registration failed: " + message)
System.exit(1)

case AddNewPSServer(serverInfo) =>
println("Adding new ps server")
psClient.synchronized {
if (psClient.isEmpty) {
val clientId: String = executorId
psClient = Some(new PSClient(clientId, context, env.conf))
}
psClient.get.addServer(serverInfo)
val serverId = serverInfo.serverId
psServers(serverId) = serverInfo
}

case NotifyClient(message) =>
println(s"message:$message")
psClient.get.notifyTasks()

case LaunchTask(data) =>
if (executor == null) {
logError("Received LaunchTask command but executor was null")
@@ -110,6 +136,8 @@ private[spark] class CoarseGrainedExecutorBackend(
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
driver ! StatusUpdate(executorId, taskId, state, data)
}

override def getPSClient: Option[PSClient] = psClient
}

private[spark] object CoarseGrainedExecutorBackend extends Logging {
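The real org.apache.spark.ps.PSClient is defined elsewhere in the patch and does not appear in this excerpt; only part of its surface can be inferred from the calls above. A minimal reconstruction of just that surface, written as our own illustrative trait:

// Reconstructed from usage in CoarseGrainedExecutorBackend only; not the actual PSClient API.
trait PSClientUsage {
  /** Register a newly launched parameter server with this executor's client. */
  def addServer(server: ServerInfo): Unit
  /** Presumably wakes tasks waiting on the parameter server (inferred from the NotifyClient case). */
  def notifyTasks(): Unit
}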
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -200,7 +200,8 @@ private[spark] class Executor(

// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
val psClient = execBackend.getPSClient
val value = task.run(psClient, taskAttemptId = taskId, attemptNumber = attemptNumber)
val taskFinish = System.currentTimeMillis()

// If the task has been killed, let's fail it.
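The extra first argument to task.run implies that Task.run itself (not shown in this excerpt) now forwards the optional PSClient into the TaskContextImpl it creates; a sketch of that presumed plumbing, where everything not visible in this diff is an assumption:

// Presumed shape of the corresponding Task.run change; the real Task class is not in this excerpt.
final def run(psClient: Option[PSClient], taskAttemptId: Long, attemptNumber: Int): T = {
  context = new TaskContextImpl(stageId, partitionId, taskAttemptId, attemptNumber,
    psClient, runningLocally = false)
  // ... the remainder of run() (context registration, metrics, cleanup) is unchanged ...
  runTask(context)
}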
core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala
@@ -20,11 +20,14 @@ package org.apache.spark.executor
import java.nio.ByteBuffer

import org.apache.spark.TaskState.TaskState
import org.apache.spark.ps.PSClient

/**
* A pluggable interface used by the Executor to send updates to the cluster scheduler.
*/
private[spark] trait ExecutorBackend {
def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)

def getPSClient: Option[PSClient]
}

core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -28,6 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.ps.PSClient
import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData}
import org.apache.spark.util.{SignalLogger, Utils}

@@ -48,6 +49,8 @@ private[spark] class MesosExecutorBackend
.build())
}

override def getPSClient: Option[PSClient] = None

override def registered(
driver: ExecutorDriver,
executorInfo: ExecutorInfo,