diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala index f7a2b7e802..9af6bd726c 100644 --- a/core/src/main/scala/spark/CacheManager.scala +++ b/core/src/main/scala/spark/CacheManager.scala @@ -19,7 +19,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { case Some(cachedValues) => // Partition is in cache, so just return its values logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] + val iter = cachedValues.asInstanceOf[Iterator[T]] + return new InterruptibleIteratorDecorator(iter) case None => // Mark the split as loading (unless someone else marks it first) @@ -37,7 +38,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // downside of the current code is that threads wait serially if this does happen. blockManager.get(key) match { case Some(values) => - return values.asInstanceOf[Iterator[T]] + val iter = values.asInstanceOf[Iterator[T]] + return new InterruptibleIteratorDecorator(iter) case None => logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") loading.add(key) @@ -53,7 +55,8 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { elements ++= rdd.computeOrReadCheckpoint(split, context) // Try to put this block in the blockManager blockManager.put(key, elements, storageLevel, true) - return elements.iterator.asInstanceOf[Iterator[T]] + val iter = elements.iterator.asInstanceOf[Iterator[T]] + return new InterruptibleIteratorDecorator(iter) } finally { loading.synchronized { loading.remove(key) diff --git a/core/src/main/scala/spark/InterruptibleIterator.scala b/core/src/main/scala/spark/InterruptibleIterator.scala new file mode 100644 index 0000000000..b853bd47eb --- /dev/null +++ b/core/src/main/scala/spark/InterruptibleIterator.scala @@ -0,0 +1,26 @@ +package spark + +trait InterruptibleIterator[+T] extends Iterator[T]{ + + override def hasNext(): Boolean = { + if (!Thread.currentThread().isInterrupted()) { + true + } else { + throw new InterruptedException ("Thread interrupted during RDD iteration") + } + } + +} + +class InterruptibleIteratorDecorator[T](delegate: Iterator[T]) + extends AnyRef with InterruptibleIterator[T] { + + override def hasNext(): Boolean = { + super.hasNext + delegate.hasNext + } + + override def next(): T = { + delegate.next() + } +} \ No newline at end of file diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index f336c2ea1e..6e7c516fff 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -81,7 +81,12 @@ abstract class RDD[T: ClassManifest]( // ======================================================================= /** Implemented by subclasses to compute a given partition. */ - def compute(split: Partition, context: TaskContext): Iterator[T] + protected def compute(split: Partition, context: TaskContext): Iterator[T] + + def computeInterruptibly(split: Partition, context: TaskContext): Iterator[T] = { + val iter = compute(split, context) + new InterruptibleIteratorDecorator(iter) + } /** * Implemented by subclasses to return the set of partitions in this RDD. 
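// Illustrative sketch (not part of the patch): the InterruptibleIterator trait and
// InterruptibleIteratorDecorator added above make every hasNext call check the running
// thread's interrupt flag. A minimal, self-contained sketch of that behaviour; the
// class/object names below are made up for illustration.
object InterruptibleIterationDemo {

  class InterruptCheckingIterator[T](delegate: Iterator[T]) extends Iterator[T] {
    override def hasNext: Boolean = {
      // Fail fast once the thread has been interrupted (e.g. by FutureTask.cancel(true)).
      if (Thread.currentThread().isInterrupted)
        throw new InterruptedException("Thread interrupted during iteration")
      delegate.hasNext
    }
    override def next(): T = delegate.next()
  }

  def main(args: Array[String]): Unit = {
    val worker = new Thread(new Runnable {
      def run(): Unit = {
        val iter = new InterruptCheckingIterator(Iterator.continually(1))
        try {
          while (iter.hasNext) iter.next()   // spins until the interrupt arrives
        } catch {
          case _: InterruptedException => println("iteration stopped after interrupt")
        }
      }
    })
    worker.start()
    Thread.sleep(200)
    worker.interrupt()   // the next hasNext call in the worker throws and ends the loop
    worker.join()
  }
}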
This method will only @@ -229,7 +234,7 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed) { firstParent[T].iterator(split, context) } else { - compute(split, context) + computeInterruptibly(split, context) } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 70a9d7698c..979a984c62 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -617,6 +617,10 @@ class SparkContext( } } + def killJob(jobId: Int, reason: String="") { + dagScheduler.killJob(jobId, reason) + } + /** * Run a function on a given set of partitions in an RDD and pass the results to the given * handler function. This is the main entry point for all actions in Spark. The allowLocal diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index 2bf55ea9a9..be716a73e5 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -2,16 +2,16 @@ package spark.executor import java.io.{File, FileOutputStream} import java.net.{URI, URL, URLClassLoader} +import java.nio.ByteBuffer import java.util.concurrent._ - +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer,ConcurrentMap, HashMap, Map} +import scala.concurrent.JavaConversions._ import org.apache.hadoop.fs.FileUtil - -import scala.collection.mutable.{ArrayBuffer, Map, HashMap} - import spark.broadcast._ import spark.scheduler._ import spark._ -import java.nio.ByteBuffer +import spark.scheduler.cluster.TaskDescription /** * The Mesos executor for Spark. @@ -79,14 +79,26 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert val threadPool = new ThreadPoolExecutor( 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + val tasks: ConcurrentMap[Long, FutureTask[_]] = new ConcurrentHashMap[Long, FutureTask[_]]() def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - threadPool.execute(new TaskRunner(context, taskId, serializedTask)) + val runner = new TaskRunner(context, taskId, serializedTask) + val task = threadPool.submit(runner).asInstanceOf[FutureTask[_]] + tasks.put(taskId, task) + + } + + def killTask(context: ExecutorBackend, taskId: Long, executorId: String) { + val task = tasks.get(taskId) + task match { + case Some(t) => t.cancel(true) + case None => + } } class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) extends Runnable { - override def run() { + override def run(): Unit = { val startTime = System.currentTimeMillis() SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) @@ -138,6 +150,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert logError("Exception in task ID " + taskId, t) //System.exit(1) } + } finally { + tasks.remove(taskId) } } } diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index ebe2ac68d8..4a7cda4337 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -55,6 +55,12 @@ private[spark] class StandaloneExecutorBackend( } else { executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) } + + case KillTask(taskId, executorId) => + logInfo("Killing Task %s %s".format(taskId, executorId)) + if (executor != null) 
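// Illustrative sketch (not part of the patch): the executor-side pattern introduced above --
// keep the FutureTask returned by the thread pool so a later kill request can cancel it.
// The names launch/kill and the Runnable body are hypothetical stand-ins for
// Executor.launchTask/killTask and TaskRunner.
import java.util.concurrent.{ConcurrentHashMap, Executors, FutureTask, TimeUnit}

object ExecutorCancelSketch {
  private val pool  = Executors.newCachedThreadPool()
  private val tasks = new ConcurrentHashMap[Long, FutureTask[_]]()

  def launch(taskId: Long, body: Runnable): Unit = {
    // ThreadPoolExecutor.submit wraps the Runnable in a FutureTask, which is what we keep.
    val future = pool.submit(body).asInstanceOf[FutureTask[_]]
    tasks.put(taskId, future)
  }

  def kill(taskId: Long): Unit = {
    val future = tasks.get(taskId)
    if (future != null) future.cancel(true)   // cancel(true) interrupts the worker thread
  }

  def main(args: Array[String]): Unit = {
    launch(1L, new Runnable {
      def run(): Unit =
        try { while (true) Thread.sleep(10) }
        catch { case _: InterruptedException => println("task 1 saw the interrupt") }
    })
    Thread.sleep(200)
    kill(1L)
    pool.shutdown()
    pool.awaitTermination(1, TimeUnit.SECONDS)
  }
}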
{ + executor.killTask(this, taskId, executorId) + } case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => logError("Driver terminated or disconnected! Shutting down.") diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index f7d60be5db..fb773f72d3 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -278,6 +278,41 @@ class DAGScheduler( return listener.awaitResult() // Will throw an exception if the job fails } + def killJob(jobId: Int, reason: String) + { + logInfo("Killing Job %s".format(jobId)) + val j = activeJobs.find(j => j.runId.equals(jobId)) + j match { + case Some(job) => killJob(job, reason) + case None => Unit + } + } + + private def killJob(job: ActiveJob, reason: String) { + logInfo("Killing Job and cleaning up stages %s".format(job.runId)) + activeJobs.remove(job) + idToActiveJob.remove(job.runId) + val stage = job.finalStage + resultStageToJob.remove(stage) + killStage(stage) + // recursively remove all parent stages + stage.parents.foreach(p => killStage(p)) + job.listener.jobFailed(new SparkException("Job failed: " + reason)) + } + + private def killStage(stage: Stage) { + logInfo("Killing Stage %s".format(stage.id)) + idToStage.remove(stage.id) + if (stage.isShuffleMap) { + shuffleToMapStage.remove(stage.id) + } + waiting.remove(stage) + pendingTasks.remove(stage) + running.remove(stage) + taskSched.killTasks(stage.id) + stage.parents.foreach(p => killStage(p)) + } + /** * Process one event retrieved from the event queue. * Returns true if we should stop the event loop. @@ -495,7 +530,11 @@ class DAGScheduler( */ private def handleTaskCompletion(event: CompletionEvent) { val task = event.task - val stage = idToStage(task.stageId) + val stageId = task.stageId + if (!idToStage.contains(stageId)) { + return; + } + val stage = idToStage(stageId) def markStageAsFinished(stage: Stage) = { val serviceTime = stage.submissionTime match { diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 83166bce22..f302d6dbd4 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -77,7 +77,7 @@ private[spark] class ResultTask[T, U]( preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs)) } - override def run(attemptId: Long): U = { + override def runInterruptibly(attemptId: Long): U = { val context = new TaskContext(stageId, partition, attemptId) metrics = Some(context.taskMetrics) try { diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index 95647389c3..d08e749eec 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -98,6 +98,11 @@ private[spark] class ShuffleMapTask( rdd.partitions(partition) } + override def kill() { + logDebug("Killing Task %s %s".format(rdd.id, partition)) + super.kill() + } + override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { split = rdd.partitions(partition) @@ -124,7 +129,7 @@ private[spark] class ShuffleMapTask( split = in.readObject().asInstanceOf[Partition] } - override def run(attemptId: Long): MapStatus = { + override def runInterruptibly(attemptId: Long): MapStatus = { val 
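// Illustrative sketch (not part of the patch): the recursive stage clean-up behind the
// killJob/killStage methods above, reduced to a toy Stage class. Only the field names id
// and parents mirror the real code; the registries here are simplified stand-ins.
import scala.collection.mutable

object KillStageSketch {
  case class Stage(id: Int, parents: Seq[Stage])

  val idToStage      = mutable.Map[Int, Stage]()   // stands in for the DAGScheduler registries
  val killedTaskSets = mutable.Buffer[Int]()       // stands in for taskSched.killTasks(stageId)

  def killStage(stage: Stage): Unit = {
    idToStage -= stage.id            // later task completions for this stage are now ignored
    killedTaskSets += stage.id
    stage.parents.foreach(killStage) // walk the whole ancestry, as the patch does
  }

  def main(args: Array[String]): Unit = {
    val shuffleStage = Stage(1, Nil)
    val resultStage  = Stage(2, Seq(shuffleStage))
    idToStage ++= Seq(1 -> shuffleStage, 2 -> resultStage)
    killStage(resultStage)
    println(killedTaskSets)          // ArrayBuffer(2, 1): the result stage, then its parent
    println(idToStage.isEmpty)       // true: both stages were deregistered
  }
}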
numOutputSplits = dep.partitioner.numPartitions val taskContext = new TaskContext(stageId, partition, attemptId) @@ -133,7 +138,6 @@ private[spark] class ShuffleMapTask( val blockManager = SparkEnv.get.blockManager var shuffle: ShuffleBlocks = null var buckets: ShuffleWriterGroup = null - try { // Obtain all the block writers for shuffle blocks. val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala index a6462c6968..9187e90a8a 100644 --- a/core/src/main/scala/spark/scheduler/Task.scala +++ b/core/src/main/scala/spark/scheduler/Task.scala @@ -1,23 +1,48 @@ package spark.scheduler -import spark.serializer.SerializerInstance import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import spark.util.ByteBufferInputStream +import java.util.concurrent.{Callable, ExecutionException, Future, FutureTask} import scala.collection.mutable.HashMap +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import spark.executor.TaskMetrics +import spark.serializer.SerializerInstance +import spark.util.ByteBufferInputStream + /** * A task to execute on a worker node. */ private[spark] abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Long): T + @volatile @transient var f: FutureTask[T] = null + + def run(attemptId: Long): T = { + f = new FutureTask(new Callable[T] { + def call(): T = { + runInterruptibly(attemptId) + } + }) + try { + f.run() + f.get() + } catch { + case e: Exception => throw e.getCause() + } + } + + def runInterruptibly(attemptId: Long): T + def preferredLocations: Seq[String] = Nil var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler. var metrics: Option[TaskMetrics] = None + + def kill(): Unit = { + if (f != null) { + f.cancel(true) + } + } } diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala index 7787b54762..2506c73fe8 100644 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala @@ -19,6 +19,8 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit + + def killTasks(stageId: Int): Unit // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. def setListener(listener: TaskSchedulerListener): Unit diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala index e4b5fcaedb..564b6d3e85 100644 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/spark/scheduler/TaskSet.scala @@ -15,4 +15,10 @@ private[spark] class TaskSet( val id: String = stageId + "." 
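// Illustrative sketch (not part of the patch): a simplified variant of the run()/kill()
// pattern that Task.scala adopts above. The work is wrapped in a FutureTask so kill() can
// call cancel(true) and interrupt whatever thread is inside doWork(). InterruptibleWork
// and doWork are hypothetical names; the real code calls these run and runInterruptibly.
import java.util.concurrent.{Callable, CancellationException, ExecutionException, FutureTask}

abstract class InterruptibleWork[T] {
  @volatile private var future: FutureTask[T] = null

  protected def doWork(): T          // plays the role of runInterruptibly(attemptId)

  def run(): T = {
    future = new FutureTask(new Callable[T] { def call(): T = doWork() })
    future.run()                     // executes doWork() on the calling thread
    try future.get()
    catch {
      // Unwrap the real failure; a cancelled task surfaces as CancellationException instead.
      case e: ExecutionException => throw e.getCause
    }
  }

  def kill(): Unit = if (future != null) future.cancel(true)
}

object InterruptibleWorkDemo {
  def main(args: Array[String]): Unit = {
    val work = new InterruptibleWork[Int] {
      def doWork(): Int = {
        while (!Thread.currentThread().isInterrupted) {}   // spin until interrupted
        42
      }
    }
    val runner = new Thread(new Runnable {
      def run(): Unit =
        try println("result: " + work.run())
        catch { case _: CancellationException => println("work was cancelled") }
    })
    runner.start()
    Thread.sleep(200)   // crude: give run() time to start before killing it
    work.kill()         // interrupts the thread spinning in doWork()
    runner.join()
  }
}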
+ attempt override def toString: String = "TaskSet " + id + + def kill() = { + tasks.foreach { + _.kill() + } + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 3a0c29b27f..b60c0023c9 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -199,6 +199,24 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } + override def killTasks(stageId: Int) { + synchronized { + schedulableBuilder.popTaskSetManagers(stageId).foreach { + t => + val ts = t.asInstanceOf[TaskSetManager].taskSet + ts.kill() + val taskIds = taskSetTaskIds(ts.id) + taskIds.foreach { + tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + + } + + } + def taskSetFinished(manager: TaskSetManager) { this.synchronized { activeTaskSets -= manager.taskSet.id diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala index 18cc15c2a5..466da493e4 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulableBuilder.scala @@ -22,6 +22,7 @@ import java.util.Properties private[spark] trait SchedulableBuilder { def buildPools() def addTaskSetManager(manager: Schedulable, properties: Properties) + def popTaskSetManagers(stageId: Int): Iterable[Schedulable] } private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging { @@ -33,6 +34,16 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) extends Schedula override def addTaskSetManager(manager: Schedulable, properties: Properties) { rootPool.addSchedulable(manager) } + + override def popTaskSetManagers(stageId: Int) = { + val s = rootPool.schedulableNameToSchedulable.values.filter { + _.stageId == stageId + } + s.foreach { + rootPool.removeSchedulable(_) + } + s + } } private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends SchedulableBuilder with Logging { @@ -112,4 +123,14 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool) extends Schedula parentPool.addSchedulable(manager) logInfo("Added task set " + manager.name + " tasks to pool "+poolName) } + + override def popTaskSetManagers(stageId: Int) = { + val s = rootPool.schedulableNameToSchedulable.values.filter { + _.stageId == stageId + } + s.foreach { + rootPool.removeSchedulable(_) + } + s + } } diff --git a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala index 9ac875de3a..56d7f81c87 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SchedulerBackend.scala @@ -10,6 +10,7 @@ import spark.Utils private[spark] trait SchedulerBackend { def start(): Unit def stop(): Unit + def killTask(taskId: Long, executorId: String): Unit def reviveOffers(): Unit def defaultParallelism(): Int @@ -22,6 +23,4 @@ private[spark] trait SchedulerBackend { .getOrElse(512) } - - // TODO: Probably want to add a killTask too } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index 3335294844..b05a47f9a9 100644 --- 
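// Illustrative sketch (not part of the patch): what popTaskSetManagers above boils down to --
// select every schedulable registered for the stage, detach it from the pool, and hand the
// selection back so the caller can kill its tasks. Pool and Entry here are simplified,
// hypothetical stand-ins for the real Pool/Schedulable classes.
import scala.collection.mutable

object PopByStageSketch {
  case class Entry(name: String, stageId: Int)

  class Pool {
    val entries = mutable.HashMap[String, Entry]()
    def add(e: Entry): Unit = entries(e.name) = e
    def popByStage(stageId: Int): Iterable[Entry] = {
      // Materialise the selection before mutating the map we are filtering over.
      val selected = entries.values.filter(_.stageId == stageId).toList
      selected.foreach(e => entries -= e.name)
      selected
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = new Pool
    pool.add(Entry("TaskSet_0.0", stageId = 0))
    pool.add(Entry("TaskSet_1.0", stageId = 1))
    println(pool.popByStage(1))    // List(Entry(TaskSet_1.0,1))
    println(pool.entries.keySet)   // Set(TaskSet_0.0): the other stage is untouched
  }
}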
a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -10,6 +10,7 @@ private[spark] sealed trait StandaloneClusterMessage extends Serializable // Driver to executors private[spark] case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage +case class KillTask(taskId: Long, executorId: String) extends StandaloneClusterMessage private[spark] case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 004592a540..3e787665e4 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -66,6 +66,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor case ReviveOffers => makeOffers() + case KillTask(taskId, executorId) => + freeCores(executorId) += 1 + executorActor(executorId) ! KillTask(taskId, executorId) + case StopDriver => sender ! true context.stop(self) @@ -104,6 +108,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } + // Remove a disconnected slave from the cluster def removeExecutor(executorId: String, reason: String) { if (executorActor.contains(executorId)) { @@ -156,6 +161,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor driverActor ! ReviveOffers } + def killTask(taskId: Long, executorId: String): Unit = { + driverActor ! KillTask(taskId, executorId) + } + override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 93d4318b29..1439c51260 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -100,6 +100,15 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: localActor ! LocalReviveOffers } } + + override def killTasks(stageId: Int) { + synchronized { + schedulableBuilder.popTaskSetManagers(stageId).foreach { + _.asInstanceOf[TaskSetManager].taskSet.kill() + } + + } + } def resourceOffer(freeCores: Int): Seq[TaskDescription] = { synchronized { diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index f4a2994b6d..a18fc3829f 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -130,6 +130,8 @@ private[spark] class CoarseMesosSchedulerBackend( override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + + override def killTask(taskId: Long, executorId: String) {} /** * Method called by Mesos to offer resources on slaves. 
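// Illustrative sketch (not part of the patch): the KillTask hop shown above, with the actor
// layer replaced by direct calls so it runs standalone. DriverSide and ExecutorSide are
// hypothetical stand-ins for the driver actor in StandaloneSchedulerBackend and for
// StandaloneExecutorBackend; the real code sends the same KillTask case class over the
// existing driver-executor actor channel.
object KillTaskFlowSketch {
  case class KillTask(taskId: Long, executorId: String)

  class ExecutorSide(executorId: String) {
    def receive(msg: KillTask): Unit =
      println("executor %s: cancelling task %d".format(executorId, msg.taskId))
  }

  class DriverSide(executors: Map[String, ExecutorSide]) {
    def receive(msg: KillTask): Unit = {
      // The real backend also returns the freed core to the executor's allocation here.
      executors(msg.executorId).receive(msg)
    }
  }

  def main(args: Array[String]): Unit = {
    val driver = new DriverSide(Map("exec-1" -> new ExecutorSide("exec-1")))
    driver.receive(KillTask(42L, "exec-1"))
  }
}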
We respond by launching an executor, diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index ca7fab4cc5..d530167ebd 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -135,6 +135,8 @@ private[spark] class MesosSchedulerBackend( override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} + + def killTask(taskId: Long, executorId: String): Unit = {} /** * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets diff --git a/core/src/main/scala/spark/util/CompletionIterator.scala b/core/src/main/scala/spark/util/CompletionIterator.scala index 8139183780..ef228e610f 100644 --- a/core/src/main/scala/spark/util/CompletionIterator.scala +++ b/core/src/main/scala/spark/util/CompletionIterator.scala @@ -1,12 +1,14 @@ package spark.util +import spark.InterruptibleIterator + /** * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements */ -abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A]{ - def next = sub.next - def hasNext = { - val r = sub.hasNext +abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends InterruptibleIterator[A]{ + override def next = sub.next + override def hasNext = { + val r = super.hasNext && sub.hasNext if (!r) { completion } diff --git a/core/src/main/scala/spark/util/NextIterator.scala b/core/src/main/scala/spark/util/NextIterator.scala index 48b5018ddd..e76a2622f7 100644 --- a/core/src/main/scala/spark/util/NextIterator.scala +++ b/core/src/main/scala/spark/util/NextIterator.scala @@ -1,7 +1,9 @@ package spark.util +import spark.InterruptibleIterator + /** Provides a basic/boilerplate Iterator implementation. 
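// Illustrative sketch (not part of the patch): why the iterator utilities here
// (CompletionIterator above, NextIterator just below) are made interrupt-aware -- an
// iterator that owns a resource should release it whether iteration ends normally or is
// cut short by an interrupt. CleanupIterator is a hypothetical name; the patch gets the
// same effect by mixing InterruptibleIterator into those classes.
class CleanupIterator[T](delegate: Iterator[T])(cleanup: () => Unit) extends Iterator[T] {
  private var closed = false
  private def closeOnce(): Unit = if (!closed) { closed = true; cleanup() }

  override def hasNext: Boolean = {
    if (Thread.currentThread().isInterrupted) {
      closeOnce()                     // release the resource before bailing out
      throw new InterruptedException("Thread interrupted during iteration")
    }
    val more = delegate.hasNext
    if (!more) closeOnce()            // normal end of iteration
    more
  }
  override def next(): T = delegate.next()
}

object CleanupIteratorDemo {
  def main(args: Array[String]): Unit = {
    val iter = new CleanupIterator(Iterator(1, 2, 3))(() => println("cleanup ran"))
    println(iter.sum)   // the final hasNext call prints "cleanup ran", then the sum 6 is printed
  }
}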
*/ -private[spark] abstract class NextIterator[U] extends Iterator[U] { +private[spark] abstract class NextIterator[U] extends InterruptibleIterator[U] { private var gotNext = false private var nextValue: U = _ @@ -49,6 +51,14 @@ private[spark] abstract class NextIterator[U] extends Iterator[U] { } override def hasNext: Boolean = { + try { + super.hasNext + } catch { + case e: Exception => { + closeIfNeeded + throw e; + } + } if (!finished) { if (!gotNext) { nextValue = getNext() diff --git a/core/src/test/scala/spark/CancellationSuite.scala b/core/src/test/scala/spark/CancellationSuite.scala new file mode 100644 index 0000000000..d5c9e4d5d4 --- /dev/null +++ b/core/src/test/scala/spark/CancellationSuite.scala @@ -0,0 +1,79 @@ +package spark + +import java.util.concurrent.{Callable, CountDownLatch, Future, FutureTask, TimeUnit} +import scala.concurrent.ops.spawn +import org.scalatest.BeforeAndAfter +import org.scalatest.FunSuite +import spark.scheduler.Task + +class CancellationSuite extends FunSuite with BeforeAndAfter { + + test("Cancel Task") { + val latch = new CountDownLatch(2) + val startLatch = new CountDownLatch(1) + val task = new Task[Int](0) { + override def run(attemptId: Long) = { + try { + startLatch.countDown() + super.run(attemptId) + } catch { + //check if interrupted exception is thrown + case e: Exception => latch.countDown() + } + 0 + } + override def runInterruptibly(attemptId: Long) = { + //check if interrupt is propagated + while(!Thread.currentThread().isInterrupted()) {} + latch.countDown() + 0 + } + } + spawn { + task.run(0) + } + startLatch.await() + Thread.sleep(100) + task.kill + val v = latch.await(5,TimeUnit.SECONDS) + assert(latch.getCount() == 0 && v) + } + test("handle interrupt during iteration") { + val latch = new CountDownLatch(1) + val innerLatch = new CountDownLatch(1) + val iter = new Iterator[Int] with InterruptibleIterator[Int] { + override def hasNext(): Boolean = super.hasNext && true + override def next(): Int = 0 + } + val f = new FutureTask(new Callable[Int]{ + override def call(): Int = { + var count=0 + latch.countDown() + try{ + while (iter.hasNext) { + count += iter.next + } + } finally { + innerLatch.countDown() + } + count + } + }) + spawn { + latch.await() + Thread.sleep(100) + f.cancel(true) + } + f.run() + try { + f.get() + } catch { + case e: Exception => { + innerLatch.await(1, TimeUnit.SECONDS) + } + } finally { + assert(innerLatch.getCount() == 0) + } + } + +} \ No newline at end of file diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala index 8e1ad27e14..6111a758eb 100644 --- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala @@ -83,9 +83,10 @@ class DummyTaskSetManager( class DummyTask(stageId: Int) extends Task[Int](stageId) { - def run(attemptId: Long): Int = { + def runInterruptibly(attemptId: Long): Int = { return 0 } + } class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index 30e6fef950..d25b08b2e7 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -48,6 +48,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont } override def setListener(listener: 
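// Illustrative sketch (not part of the patch): the FutureTask semantics the tests above rely
// on -- after cancel(true) the thread running the task is interrupted, and get() reports
// CancellationException instead of a result. Names below are made up for illustration.
import java.util.concurrent.{Callable, CancellationException, FutureTask}

object CancelSemanticsSketch {
  def main(args: Array[String]): Unit = {
    val task = new FutureTask(new Callable[Int] {
      def call(): Int = {
        try Thread.sleep(10000) catch { case _: InterruptedException => () }   // woken early
        -1
      }
    })
    val runner = new Thread(task)   // FutureTask is a Runnable, so a plain Thread can run it
    runner.start()
    Thread.sleep(200)               // crude: let the task start sleeping first
    task.cancel(true)               // interrupts runner and marks the task cancelled
    runner.join()
    try println(task.get())
    catch { case _: CancellationException => println("get() reported the cancellation") }
  }
}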
TaskSchedulerListener) = {} override def defaultParallelism() = 2 + override def killTasks(stageId: Int) = {} } var mapOutputTracker: MapOutputTracker = null diff --git a/core/src/test/scala/spark/util/NextIteratorSuite.scala b/core/src/test/scala/spark/util/NextIteratorSuite.scala index ed5b36da73..95874b37ed 100644 --- a/core/src/test/scala/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/spark/util/NextIteratorSuite.scala @@ -3,7 +3,9 @@ package spark.util import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers import scala.collection.mutable.Buffer +import scala.concurrent.ops.spawn import java.util.NoSuchElementException +import java.util.concurrent.{Callable, CountDownLatch, Future, FutureTask, TimeUnit} class NextIteratorSuite extends FunSuite with ShouldMatchers { test("one iteration") { @@ -49,6 +51,54 @@ class NextIteratorSuite extends FunSuite with ShouldMatchers { i.closeCalled should be === 1 } + test("close is called upon interruption") { + val latch = new CountDownLatch(1) + val startLatch = new CountDownLatch(1) + var closeCalled = 0 + + val iter = new NextIterator[Int] { + + override def getNext() = { + if (latch.getCount() == 0) { + finished = true + 0 + } + 1 + } + + override def close() { + latch.countDown() + closeCalled += 1 + } + } + val f = new FutureTask(new Callable[Int]{ + override def call(): Int = { + var count=0 + startLatch.countDown() + while (iter.hasNext) { + count += iter.next + } + count + } + }) + + spawn { + startLatch.await() + Thread.sleep(100) + f.cancel(true) + } + f.run() + try { + f.get() + } catch { + case e: InterruptedException => { + latch.await(1, TimeUnit.SECONDS) + } + } finally { + assert(latch.getCount() == 0 && closeCalled == 1) + } + } + class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0