diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index b320be8863..87a4143877 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -60,6 +60,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
 
   val nextRunId = new AtomicInteger(0)
 
+  val runIdToStageIds = new HashMap[Int, HashSet[Int]]
+
   val nextStageId = new AtomicInteger(0)
 
   val idToStage = new TimeStampedHashMap[Int, Stage]
@@ -143,6 +145,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     val id = nextStageId.getAndIncrement()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
     idToStage(id) = stage
+    val stageIdSet = runIdToStageIds.getOrElseUpdate(priority, new HashSet)
+    stageIdSet += id
     stage
   }
 
@@ -285,6 +289,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
         case StopDAGScheduler =>
           // Cancel any active jobs
           for (job <- activeJobs) {
+            removeStages(job)
             val error = new SparkException("Job cancelled because SparkContext was shut down")
             job.listener.jobFailed(error)
           }
@@ -420,13 +425,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
                 if (!job.finished(rt.outputId)) {
                   job.finished(rt.outputId) = true
                   job.numFinished += 1
-                  job.listener.taskSucceeded(rt.outputId, event.result)
                   // If the whole job has finished, remove it
                   if (job.numFinished == job.numPartitions) {
                     activeJobs -= job
                     resultStageToJob -= stage
                     running -= stage
+                    removeStages(job)
                   }
+                  job.listener.taskSucceeded(rt.outputId, event.result)
                 }
               case None =>
                 logInfo("Ignoring result from " + rt + " because its job has finished")
@@ -558,9 +564,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
     for (resultStage <- dependentStages) {
       val job = resultStageToJob(resultStage)
-      job.listener.jobFailed(new SparkException("Job failed: " + reason))
       activeJobs -= job
       resultStageToJob -= resultStage
+      removeStages(job)
+      job.listener.jobFailed(new SparkException("Job failed: " + reason))
     }
     if (dependentStages.isEmpty) {
       logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -637,6 +644,19 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
     logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
   }
 
+  def removeStages(job: ActiveJob) = {
+    runIdToStageIds(job.runId).foreach(stageId => {
+      idToStage.get(stageId).map( stage => {
+        pendingTasks -= stage
+        waiting -= stage
+        running -= stage
+        failed -= stage
+      })
+      idToStage -= stageId
+    })
+    runIdToStageIds -= job.runId
+  }
+
   def stop() {
     eventQueue.put(StopDAGScheduler)
     metadataCleaner.cancel()
diff --git a/core/src/test/scala/spark/DAGSchedulerSuite.scala b/core/src/test/scala/spark/DAGSchedulerSuite.scala
new file mode 100644
index 0000000000..2a3b30ae42
--- /dev/null
+++ b/core/src/test/scala/spark/DAGSchedulerSuite.scala
@@ -0,0 +1,88 @@
+package spark
+
+import org.scalatest.FunSuite
+import scheduler.{DAGScheduler, TaskSchedulerListener, TaskSet, TaskScheduler}
+import collection.mutable
+
+class TaskSchedulerMock(f: (Int) => TaskEndReason ) extends TaskScheduler {
+  // Listener object to pass upcalls into
+  var listener: TaskSchedulerListener = null
+  var taskCount = 0
+
+  override def start(): Unit = {}
+
+  // Disconnect from the cluster.
+  override def stop(): Unit = {}
+
+  // Submit a sequence of tasks to run.
+  override def submitTasks(taskSet: TaskSet): Unit = {
+    taskSet.tasks.foreach( task => {
+      val m = new mutable.HashMap[Long, Any]()
+      m.put(task.stageId, 1)
+      taskCount += 1
+      listener.taskEnded(task, f(taskCount), 1, m)
+    })
+  }
+
+  // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
+  override def setListener(listener: TaskSchedulerListener) {
+    this.listener = listener
+  }
+
+  // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+  override def defaultParallelism(): Int = {
+    2
+  }
+}
+
+class DAGSchedulerSuite extends FunSuite {
+  def assertDagSchedulerEmpty(dagScheduler: DAGScheduler) = {
+    assert(dagScheduler.pendingTasks.isEmpty)
+    assert(dagScheduler.activeJobs.isEmpty)
+    assert(dagScheduler.failed.isEmpty)
+    assert(dagScheduler.runIdToStageIds.isEmpty)
+    assert(dagScheduler.idToStage.isEmpty)
+    assert(dagScheduler.resultStageToJob.isEmpty)
+    assert(dagScheduler.running.isEmpty)
+    assert(dagScheduler.shuffleToMapStage.isEmpty)
+    assert(dagScheduler.waiting.isEmpty)
+  }
+
+  test("oneGoodJob") {
+    val sc = new SparkContext("local", "test")
+    val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
+    try {
+      val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
+      val func = (tc: TaskContext, iter: Iterator[Int]) => 1
+      val callSite = Utils.getSparkCallSite
+
+      val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
+      assertDagSchedulerEmpty(dagScheduler)
+    } finally {
+      dagScheduler.stop()
+      sc.stop()
+      // pause to let dagScheduler stop (separate thread)
+      Thread.sleep(10)
+    }
+  }
+
+  test("manyGoodJobs") {
+    val sc = new SparkContext("local", "test")
+    val dagScheduler = new DAGScheduler(new TaskSchedulerMock(count => Success))
+    try {
+      val rdd = new ParallelCollection(sc, 1.to(100).toSeq, 5, Map.empty)
+      val func = (tc: TaskContext, iter: Iterator[Int]) => 1
+      val callSite = Utils.getSparkCallSite
+
+      1.to(100).foreach( v => {
+        val result = dagScheduler.runJob(rdd, func, 0 until rdd.splits.size, callSite, false)
+      })
+      assertDagSchedulerEmpty(dagScheduler)
+    } finally {
+      dagScheduler.stop()
+      sc.stop()
+      // pause to let dagScheduler stop (separate thread)
+      Thread.sleep(10)
+    }
+  }
+}