Support Cancellation of Spark Jobs #665

Closed · wants to merge 2 commits
9 changes: 6 additions & 3 deletions core/src/main/scala/spark/CacheManager.scala
@@ -19,7 +19,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
case Some(cachedValues) =>
// Partition is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
val iter = cachedValues.asInstanceOf[Iterator[T]]
return new InterruptibleIteratorDecorator(iter)

case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -37,7 +38,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
val iter = values.asInstanceOf[Iterator[T]]
return new InterruptibleIteratorDecorator(iter)
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
@@ -53,7 +55,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
elements ++= rdd.computeOrReadCheckpoint(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
val iter = elements.iterator.asInstanceOf[Iterator[T]]
Contributor commented: excess indentation

return new InterruptibleIteratorDecorator(iter)
} finally {
loading.synchronized {
loading.remove(key)
26 changes: 26 additions & 0 deletions core/src/main/scala/spark/InterruptibleIterator.scala
@@ -0,0 +1,26 @@
package spark

trait InterruptibleIterator[+T] extends Iterator[T]{

override def hasNext(): Boolean = {
if (!Thread.currentThread().isInterrupted()) {
true
} else {
throw new InterruptedException ("Thread interrupted during RDD iteration")
}
}

}

class InterruptibleIteratorDecorator[T](delegate: Iterator[T])
extends AnyRef with InterruptibleIterator[T] {

override def hasNext(): Boolean = {
super.hasNext
Contributor commented: It seems kinda weird to me that InterruptibleIterator is actually implementing hasNext, which you then override here. Maybe it should have a method exceptionIfThreadInterrupted. It seems like the trait is not actually implementing hasNext at all; it's just supplying a utility method for implementations.

Contributor commented: Yes, this seems odd to me, too. Is there some reason why InterruptibleIterator can't be implemented more like CompletionIterator? In other words, class InterruptibleIteratorDecorator just becomes class InterruptibleIterator, which extends Iterator; and your InterruptibleIterator trait becomes object InterruptibleIterator, which defines a function called something like notInterrupted to replace all of the calls to InterruptibleIterator.hasNext -- i.e. the peculiar super.hasNext calls become something like InterruptibleIterator.notInterrupted. (A rough sketch of this shape follows this file's diff.)

delegate.hasNext
}

override def next(): T = {
delegate.next()
}
}
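
A rough sketch of the shape that second comment describes, assuming the notInterrupted name the reviewer suggests; this illustrates the suggestion and is not code from this pull request:

```scala
package spark

// Sketch only: the decorator becomes the one InterruptibleIterator class, and the
// interruption check lives in a companion object instead of a trait's hasNext.
object InterruptibleIterator {
  // Throws InterruptedException if the current thread has been interrupted,
  // otherwise returns true so it can be chained with &&.
  def notInterrupted(): Boolean = {
    if (Thread.currentThread().isInterrupted()) {
      throw new InterruptedException("Thread interrupted during RDD iteration")
    }
    true
  }
}

class InterruptibleIterator[+T](delegate: Iterator[T]) extends Iterator[T] {
  // Check for interruption before delegating, so iteration over a partition
  // notices a cancelled task promptly.
  override def hasNext: Boolean = InterruptibleIterator.notInterrupted() && delegate.hasNext

  override def next(): T = delegate.next()
}
```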
9 changes: 7 additions & 2 deletions core/src/main/scala/spark/RDD.scala
@@ -81,7 +81,12 @@ abstract class RDD[T: ClassManifest](
// =======================================================================

/** Implemented by subclasses to compute a given partition. */
def compute(split: Partition, context: TaskContext): Iterator[T]
protected def compute(split: Partition, context: TaskContext): Iterator[T]

def computeInterruptibly(split: Partition, context: TaskContext): Iterator[T] = {
val iter = compute(split, context)
new InterruptibleIteratorDecorator(iter)
}

/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
@@ -229,7 +234,7 @@ abstract class RDD[T: ClassManifest](
if (isCheckpointed) {
firstParent[T].iterator(split, context)
} else {
compute(split, context)
computeInterruptibly(split, context)
}
}

4 changes: 4 additions & 0 deletions core/src/main/scala/spark/SparkContext.scala
@@ -617,6 +617,10 @@ class SparkContext(
}
}

def killJob(jobId: Int, reason: String="") {
dagScheduler.killJob(jobId, reason)
}

/**
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
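
For orientation, a minimal usage sketch of the new API: cancelling a running job from another thread. How a caller obtains the job's runId, and which schedulers honor the kill, are outside this diff, so the id below is a placeholder and the whole example is illustrative only:

```scala
import spark.SparkContext

object KillJobExample {  // illustrative name, not part of the PR
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "kill-job-example")

    // Run a deliberately slow action in a background thread.
    val worker = new Thread(new Runnable {
      def run() {
        try {
          sc.parallelize(1 to 1000000, 10).map { x => Thread.sleep(1); x }.count()
        } catch {
          case e: Exception => println("Job failed: " + e.getMessage)
        }
      }
    })
    worker.start()

    Thread.sleep(2000)
    val jobId = 0  // placeholder: the runId of the job to cancel
    sc.killJob(jobId, "cancelled by user")

    worker.join()
    sc.stop()
  }
}
```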
28 changes: 21 additions & 7 deletions core/src/main/scala/spark/executor/Executor.scala
@@ -2,16 +2,16 @@ package spark.executor

import java.io.{File, FileOutputStream}
import java.net.{URI, URL, URLClassLoader}
import java.nio.ByteBuffer
import java.util.concurrent._

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer,ConcurrentMap, HashMap, Map}
import scala.concurrent.JavaConversions._
import org.apache.hadoop.fs.FileUtil

import scala.collection.mutable.{ArrayBuffer, Map, HashMap}

import spark.broadcast._
import spark.scheduler._
import spark._
import java.nio.ByteBuffer
import spark.scheduler.cluster.TaskDescription

/**
* The Mesos executor for Spark.
@@ -79,14 +79,26 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])

val tasks: ConcurrentMap[Long, FutureTask[_]] = new ConcurrentHashMap[Long, FutureTask[_]]()
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
threadPool.execute(new TaskRunner(context, taskId, serializedTask))
val runner = new TaskRunner(context, taskId, serializedTask)
val task = threadPool.submit(runner).asInstanceOf[FutureTask[_]]
tasks.put(taskId, task)

}

def killTask(context: ExecutorBackend, taskId: Long, executorId: String) {
val task = tasks.get(taskId)
task match {
case Some(t) => t.cancel(true)
Member commented: It may be good to do tasks.remove(taskId) here too, in case the task thread never started and so we never got to the "finally" in its run() method. (Sketched after this file's diff.)

Author replied: Good point, will do. Thanks.

case None =>
}
}

class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {

override def run() {
override def run(): Unit = {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
@@ -138,6 +150,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
logError("Exception in task ID " + taskId, t)
//System.exit(1)
}
} finally {
tasks.remove(taskId)
}
}
}
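
A sketch of the follow-up agreed on above: remove the map entry eagerly so a task that never started (and therefore never reaches the finally block in run()) does not linger. It assumes the same tasks map as in the diff and is not the committed code:

```scala
def killTask(context: ExecutorBackend, taskId: Long, executorId: String) {
  // remove() returns an Option, so cancel only runs if the task is still tracked;
  // removing here also covers tasks whose threads never started.
  tasks.remove(taskId).foreach(_.cancel(true))
}
```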
core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -55,6 +55,12 @@ private[spark] class StandaloneExecutorBackend(
} else {
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}

case KillTask(taskId, executorId) =>
logInfo("Killing Task %s %s".format(taskId, executorId))
if (executor != null) {
executor.killTask(this, taskId, executorId)
}

case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
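
The KillTask message matched above is not defined anywhere in this diff; presumably it is added alongside the other standalone cluster messages. A guess at its shape, labelled as an assumption rather than the PR's actual definition:

```scala
// Assumption: a message carrying the task to kill and the executor that owns it,
// defined next to the existing LaunchTask-style messages.
private[spark] case class KillTask(taskId: Long, executorId: String)
  extends StandaloneClusterMessage
```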
41 changes: 40 additions & 1 deletion core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -278,6 +278,41 @@ class DAGScheduler(
return listener.awaitResult() // Will throw an exception if the job fails
}

def killJob(jobId: Int, reason: String)
{
logInfo("Killing Job %s".format(jobId))
val j = activeJobs.find(j => j.runId.equals(jobId))
j match {
case Some(job) => killJob(job, reason)
case None => Unit
}
Contributor commented: j.map(killJob(_, reason)) instead of pattern-match boilerplate with a no-op None case. Similarly for other instances where you only want to do something when you have Some(thing).

Author replied: Nice, thanks... did not know this pattern.

Contributor commented: See http://ymasory.github.io/error-handling-in-scala/#slide-52. Not everything that Yuvi presents there can be used in pre-2.10 Scala.

Contributor commented: Oh, and in this particular case, where killJob evaluates to Unit, j.foreach(killJob(_, reason)) is a little clearer than the map. Use map when you actually want to produce Some(result). (Sketched after this file's diff.)

}

private def killJob(job: ActiveJob, reason: String) {
Contributor commented: @mateiz Doesn't this have the same problem discussed in #414, where more than one ActiveJob can share a stage?

Member replied: Yes, that is actually true. To do this properly we'll need some kind of reference counting on the stages (keep a list of which jobs currently want to run each stage). One difference here is that killJob is called by the user, and for the first use case, Shark, it's probably going to be fine. But it would be good to either track this properly or send a warning.

Contributor replied: That's pretty much the conclusion I was arriving at. I'll work on the reference-counting refactoring. It should be doable independently of this PR and should only require a minimal change here once it is done.

Member replied: Cool, that would be great to have.

logInfo("Killing Job and cleaning up stages %s".format(job.runId))
activeJobs.remove(job)
idToActiveJob.remove(job.runId)
val stage = job.finalStage
resultStageToJob.remove(stage)
killStage(stage)
// recursively remove all parent stages
stage.parents.foreach(p => killStage(p))
job.listener.jobFailed(new SparkException("Job failed: " + reason))
}

private def killStage(stage: Stage) {
logInfo("Killing Stage %s".format(stage.id))
idToStage.remove(stage.id)
if (stage.isShuffleMap) {
shuffleToMapStage.remove(stage.id)
}
waiting.remove(stage)
pendingTasks.remove(stage)
running.remove(stage)
taskSched.killTasks(stage.id)
stage.parents.foreach(p => killStage(p))
}

/**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
@@ -495,7 +530,11 @@
*/
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
val stage = idToStage(task.stageId)
val stageId = task.stageId
if (!idToStage.contains(stageId)) {
return;
}
val stage = idToStage(stageId)

def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.submissionTime match {
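
Applying the Option suggestion from the thread above to the public killJob gives something like the following; a sketch of the suggested style (assuming the same activeJobs set and private killJob overload as in the diff), not the committed code:

```scala
def killJob(jobId: Int, reason: String) {
  logInfo("Killing Job %s".format(jobId))
  // foreach on the Option replaces the pattern match with a no-op None case:
  // the body runs only when a matching ActiveJob exists.
  activeJobs.find(_.runId == jobId).foreach(killJob(_, reason))
}
```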
2 changes: 1 addition & 1 deletion core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -77,7 +77,7 @@ private[spark] class ResultTask[T, U](
preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs))
}

override def run(attemptId: Long): U = {
override def runInterruptibly(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
metrics = Some(context.taskMetrics)
try {
8 changes: 6 additions & 2 deletions core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -98,6 +98,11 @@ private[spark] class ShuffleMapTask(
rdd.partitions(partition)
}

override def kill() {
logDebug("Killing Task %s %s".format(rdd.id, partition))
super.kill()
}

override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
split = rdd.partitions(partition)
@@ -124,7 +129,7 @@ private[spark] class ShuffleMapTask(
split = in.readObject().asInstanceOf[Partition]
}

override def run(attemptId: Long): MapStatus = {
override def runInterruptibly(attemptId: Long): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions

val taskContext = new TaskContext(stageId, partition, attemptId)
@@ -133,7 +138,6 @@ private[spark] class ShuffleMapTask(
val blockManager = SparkEnv.get.blockManager
var shuffle: ShuffleBlocks = null
var buckets: ShuffleWriterGroup = null

try {
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
33 changes: 29 additions & 4 deletions core/src/main/scala/spark/scheduler/Task.scala
@@ -1,23 +1,48 @@
package spark.scheduler

import spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.util.ByteBufferInputStream
import java.util.concurrent.{Callable, ExecutionException, Future, FutureTask}
import scala.collection.mutable.HashMap
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import spark.executor.TaskMetrics
import spark.serializer.SerializerInstance
import spark.util.ByteBufferInputStream


/**
* A task to execute on a worker node.
*/
private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
def run(attemptId: Long): T
@volatile @transient var f: FutureTask[T] = null

def run(attemptId: Long): T = {
f = new FutureTask(new Callable[T] {
def call(): T = {
runInterruptibly(attemptId)
}
})
try {
f.run()
f.get()
} catch {
case e: Exception => throw e.getCause()
Member commented: Should we be catching Throwable here?

Author replied: Yes, you are right; we should be catching Throwable here. (Sketched after this diff.)

}
}

def runInterruptibly(attemptId: Long): T

def preferredLocations: Seq[String] = Nil

var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.

var metrics: Option[TaskMetrics] = None

def kill(): Unit = {
if (f != null) {
f.cancel(true)
}
}

}

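
Following the exchange above about catching Throwable, the run wrapper could unwrap only the ExecutionException that FutureTask uses for failures and rethrow everything else as-is. A sketch meant to slot into the Task class shown above, not the committed code:

```scala
def run(attemptId: Long): T = {
  f = new FutureTask(new Callable[T] {
    def call(): T = runInterruptibly(attemptId)
  })
  try {
    f.run()
    f.get()
  } catch {
    // FutureTask wraps exceptions thrown by call() in an ExecutionException;
    // rethrow the original cause in that case.
    case e: ExecutionException => throw e.getCause()
    // Covers InterruptedException, CancellationException, Errors, etc.
    case t: Throwable => throw t
  }
}
```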
2 changes: 2 additions & 0 deletions core/src/main/scala/spark/scheduler/TaskScheduler.scala
@@ -19,6 +19,8 @@ private[spark] trait TaskScheduler {

// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit

def killTasks(stageId: Int): Unit

// Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
def setListener(listener: TaskSchedulerListener): Unit
6 changes: 6 additions & 0 deletions core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -15,4 +15,10 @@ private[spark] class TaskSet(
val id: String = stageId + "." + attempt

override def toString: String = "TaskSet " + id

def kill() = {
tasks.foreach {
_.kill()
}
}
}
18 changes: 18 additions & 0 deletions core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -199,6 +199,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}

override def killTasks(stageId: Int) {
synchronized {
schedulableBuilder.popTaskSetManagers(stageId).foreach {
t =>
val ts = t.asInstanceOf[TaskSetManager].taskSet
ts.kill()
val taskIds = taskSetTaskIds(ts.id)
taskIds.foreach {
tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId)
}
}

}

}

def taskSetFinished(manager: TaskSetManager) {
this.synchronized {
activeTaskSets -= manager.taskSet.id
core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala
@@ -22,6 +22,7 @@ import java.util.Properties
private[spark] trait SchedulableBuilder {
def buildPools()
def addTaskSetManager(manager: Schedulable, properties: Properties)
def popTaskSetManagers(stageId: Int): Iterable[Schedulable]
}

private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging {
@@ -33,6 +34,16 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends Schedula
override def addTaskSetManager(manager: Schedulable, properties: Properties) {
rootPool.addSchedulable(manager)
}

override def popTaskSetManagers(stageId: Int) = {
val s = rootPool.schedulableNameToSchedulable.values.filter {
_.stageId == stageId
}
s.foreach {
rootPool.removeSchedulable(_)
}
s
}
}

private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging {
@@ -112,4 +123,14 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends Schedula
parentPool.addSchedulable(manager)
logInfo("Added task set " + manager.name + " tasks to pool "+poolName)
}

override def popTaskSetManagers(stageId: Int) = {
Contributor commented: Can you avoid repeating this method, e.g. by moving it to a base class? (Sketched after this diff.)

Author replied: Good point, thanks, will do.

val s = rootPool.schedulableNameToSchedulable.values.filter {
_.stageId == stageId
}
s.foreach {
rootPool.removeSchedulable(_)
}
s
}
}
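
One way to avoid the duplication flagged above is to give the trait an abstract rootPool and a default popTaskSetManagers that both builders inherit. A sketch under that assumption (same file and types as today), not the committed code:

```scala
private[spark] trait SchedulableBuilder {
  def rootPool: Pool  // both builders already take this as a constructor val

  def buildPools()
  def addTaskSetManager(manager: Schedulable, properties: Properties)

  // Shared default: remove and return every schedulable belonging to the stage.
  def popTaskSetManagers(stageId: Int): Iterable[Schedulable] = {
    val matching = rootPool.schedulableNameToSchedulable.values.filter(_.stageId == stageId)
    matching.foreach(rootPool.removeSchedulable(_))
    matching
  }
}
```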