diff --git a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index facbb830a60d..bb3bf743d7ca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -17,6 +17,9 @@ package org.apache.spark.rdd.util +import scala.collection.Set +import scala.collection.mutable + import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -38,12 +41,15 @@ import org.apache.spark.util.PeriodicCheckpointer * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. * - If using checkpointing and the checkpoint interval has been reached, * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. - * - Remove older checkpoints. + * - Remove older checkpoints, except for the new one and any checkpoints it depends on. * * WARNINGS: * - This class should NOT be copied (since copies may conflict on which RDDs should be * checkpointed). - * - This class removes checkpoint files once later RDDs have been checkpointed. + * - This class removes checkpoint files once later RDDs have been checkpointed and no longer + * depend on the RDDs whose files would be removed (removing checkpoint files of prior RDDs + * that later RDDs still depend on may fail with `FileNotFoundException` if those later RDDs + * have not yet been materialized). * However, references to the older RDDs will still return isCheckpointed = true. * * Example usage: @@ -73,8 +79,6 @@ import org.apache.spark.util.PeriodicCheckpointer * * @param checkpointInterval RDDs will be checkpointed at this interval * @tparam T RDD element type - * - * TODO: Move this out of MLlib?
*/ private[spark] class PeriodicRDDCheckpointer[T]( checkpointInterval: Int, @@ -94,6 +98,34 @@ private[spark] class PeriodicRDDCheckpointer[T]( override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { - data.getCheckpointFile.map(x => x) + PeriodicRDDCheckpointer.rddDeps(data).flatMap(_.getCheckpointFile) + } + + override protected def haveCommonCheckpoint(newData: RDD[T], oldData: RDD[T]): Boolean = { + PeriodicRDDCheckpointer.haveCommonCheckpoint(Set(newData), Set(oldData)) + } + +} + +private[spark] object PeriodicRDDCheckpointer { + + def rddDeps(rdd: RDD[_]): Set[RDD[_]] = { + val parents = new mutable.HashSet[RDD[_]] + def visit(rdd: RDD[_]): Unit = { + parents.add(rdd) + rdd.dependencies.foreach(dep => visit(dep.rdd)) + } + visit(rdd) + parents + } + + def haveCommonCheckpoint(rdds1: Set[_ <: RDD[_]], rdds2: Set[_ <: RDD[_]]): Boolean = { + val deps1 = rdds1.foldLeft(new mutable.HashSet[RDD[_]]()) { (set, rdd) => + set ++= rddDeps(rdd) + } + val deps2 = rdds2.foldLeft(new mutable.HashSet[RDD[_]]()) { (set, rdd) => + set ++= rddDeps(rdd) + } + deps1.intersect(deps2).exists(_.isCheckpointed) } } diff --git a/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index ce06e18879a4..9c416a950175 100644 --- a/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -97,15 +97,7 @@ private[spark] abstract class PeriodicCheckpointer[T]( checkpoint(newData) checkpointQueue.enqueue(newData) // Remove checkpoints before the latest one. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // Delete the oldest checkpoint only if the next checkpoint exists. - if (isCheckpointed(checkpointQueue.head)) { - removeCheckpointFile() - } else { - canDelete = false - } - } + deleteAllCheckpointsButLast() } } @@ -127,6 +119,11 @@ private[spark] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Checks whether the two datasets depend on the same checkpointed data. + */ + protected def haveCommonCheckpoint(newData: T, oldData: T): Boolean + /** * Call this to unpersist the Dataset. */ @@ -137,22 +134,38 @@ private[spark] abstract class PeriodicCheckpointer[T]( } } + /** + * Gets the last checkpoint if one is available. + */ + def getLastCheckpoint: Option[T] = { + checkpointQueue.lastOption + } + /** * Call this at the end to delete any remaining checkpoint files. */ def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.nonEmpty) { - removeCheckpointFile() - } + deleteAllCheckpoints(_ => true) + } + + /** + * Deletes all the checkpoints that match the given predicate. + */ + def deleteAllCheckpoints(f: T => Boolean): Unit = { + val checkpoints = checkpointQueue.dequeueAll(f) + checkpoints.foreach(removeCheckpointFile) } /** * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint. - * Note that there may not be any checkpoints at all. + * Note that there may not be any checkpoints at all, and if there is more than one + * checkpoint, any checkpoints that the last one depends on are also retained.
*/ def deleteAllCheckpointsButLast(): Unit = { - while (checkpointQueue.size > 1) { - removeCheckpointFile() + getLastCheckpoint.foreach { last => + deleteAllCheckpoints { item => + item != last && !haveCommonCheckpoint(last, item) + } } } @@ -171,7 +184,14 @@ private[spark] abstract class PeriodicCheckpointer[T]( private def removeCheckpointFile(): Unit = { val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, we manually delete it. - getCheckpointFiles(old).foreach( + removeCheckpointFile(old) + } + + /** + * Removes checkpoint files of the provided Dataset. + */ + private def removeCheckpointFile(item: T): Unit = { + getCheckpointFiles(item).foreach( PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration)) } } diff --git a/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala index f9e1b791c86e..9ba8b8bc4a41 100644 --- a/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -15,15 +15,15 @@ * limitations under the License. */ -package org.apache.spark.utils +package org.apache.spark.util import org.apache.hadoop.fs.Path import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} +import org.apache.spark.rdd.MapPartitionsRDD import org.apache.spark.rdd.RDD import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext { @@ -79,6 +79,197 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext Utils.deleteRecursively(tempDir) } + + test("Getting RDD dependencies should return RDD itself") { + val rdd = sc.emptyRDD[Int] + assert(PeriodicRDDCheckpointer.rddDeps(rdd) == Set(rdd)) + } + + test("Getting RDD dependencies should return all the DAG RDDs") { + val data = 0 until 10 + val initialRdd = sc.parallelize(data) + val targetRdd = data.foldLeft(initialRdd) { (rdd, num) => + rdd.filter(_ == num) + } + + val deps = PeriodicRDDCheckpointer.rddDeps(targetRdd) + assert(deps.size == data.size + 1) + assert(deps.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) == data.size) + } + + test("Common checkpoint should be found when RDDs are related") { + val tempDir = Utils.createTempDir() + try { + sc.setCheckpointDir(tempDir.toURI.toString) + + val rdd1 = createRDD(sc) + rdd1.checkpoint() + rdd1.count() + + val rdd2 = rdd1.filter(_ => true) + + assert(PeriodicRDDCheckpointer.haveCommonCheckpoint(Set(rdd1), Set(rdd2))) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("Common checkpoint should not be found when RDDs are unrelated") { + val tempDir = Utils.createTempDir() + try { + sc.setCheckpointDir(tempDir.toURI.toString) + + val rdd1 = createRDD(sc) + rdd1.checkpoint() + rdd1.count() + + val rdd2 = createRDD(sc) + rdd2.checkpoint() + rdd2.count() + + assert(!PeriodicRDDCheckpointer.haveCommonCheckpoint(Set(rdd1), Set(rdd2))) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("Checkpointing of dependent RDD should not fail when materializing it") { + val tempDir = Utils.createTempDir() + try { + val checkpointInterval = 2 + sc.setCheckpointDir(tempDir.toURI.toString) + + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc) + + val rdd1 = createRDD(sc) + checkpointer.update(rdd1) + 
checkpointer.update(rdd1) + rdd1.count() + + val rdd2 = rdd1.filter(_ => true) + checkpointer.update(rdd2) + checkpointer.update(rdd2) + rdd2.count() + + checkpointer.deleteAllCheckpoints() + Seq(rdd1, rdd2).foreach { rdd => + confirmCheckpointRemoved(rdd) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("deleteAllCheckpointsButLast should retain last checkpoint only when RDDs are unrelated") { + val tempDir = Utils.createTempDir() + try { + val checkpointInterval = 2 + sc.setCheckpointDir(tempDir.toURI.toString) + + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc) + + val rdd1 = createRDD(sc) + checkpointer.update(rdd1) + checkpointer.update(rdd1) + rdd1.count() + + val rdd2 = createRDD(sc) + checkpointer.update(rdd2) + checkpointer.update(rdd2) + + checkpointer.deleteAllCheckpointsButLast() + Seq(rdd1).foreach(confirmCheckpointRemoved) + Seq(rdd2).foreach(confirmCheckpointExists) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("deleteAllCheckpointsButLast should retain last checkpoint and dependent checkpoints " + + "when RDDs are related") { + val tempDir = Utils.createTempDir() + try { + val checkpointInterval = 2 + sc.setCheckpointDir(tempDir.toURI.toString) + + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc) + + val rdd1 = createRDD(sc) + checkpointer.update(rdd1) + checkpointer.update(rdd1) + rdd1.count() + + val rdd2 = rdd1.filter(_ => true) + checkpointer.update(rdd2) + checkpointer.update(rdd2) + + checkpointer.deleteAllCheckpointsButLast() + Seq(rdd1, rdd2).foreach(confirmCheckpointExists) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("deleteAllCheckpoints should remove all the checkpoints") { + val tempDir = Utils.createTempDir() + try { + val checkpointInterval = 2 + sc.setCheckpointDir(tempDir.toURI.toString) + + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc) + + val rdd1 = createRDD(sc) + checkpointer.update(rdd1) + checkpointer.update(rdd1) + rdd1.count() + + val rdd2 = rdd1.filter(_ => true) + checkpointer.update(rdd2) + checkpointer.update(rdd2) + + checkpointer.deleteAllCheckpoints() + Seq(rdd1, rdd2).foreach(confirmCheckpointRemoved) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("deleteAllCheckpoints should remove all the checkpoint files when " + + "there is just parent checkpointed RDD") { + val tempDir = Utils.createTempDir() + try { + val checkpointInterval = 2 + sc.setCheckpointDir(tempDir.toURI.toString) + + val checkpointer = new PeriodicRDDCheckpointer[(Int, Int)](checkpointInterval, sc) + val rdd1 = sc.makeRDD((0 until 10).map(i => i -> i)).setName("rdd1") + + // rdd1 is not materialized yet, checkpointer(update=1, checkpointInterval=2) + checkpointer.update(rdd1) + // rdd2 depends on rdd1 + val rdd2 = rdd1.filter(_ => true).setName("rdd2") + + // rdd1 is materialized, checkpointer(update=2, checkpointInterval=2) + checkpointer.update(rdd1) + // rdd3 depends on rdd1 + val rdd3 = rdd1.filter(_ => true).setName("rdd3") + + // rdd3 is not materialized yet, checkpointer(update=3, checkpointInterval=2) + checkpointer.update(rdd3) + // rdd3 is materialized, rdd1 is removed, checkpointer(update=4, checkpointInterval=2) + checkpointer.update(rdd3) + + // should not fail + rdd2.count() + + checkpointer.deleteAllCheckpoints() + Seq(rdd1, rdd2, rdd3).foreach { rdd => + confirmCheckpointRemoved(rdd) + } + } finally { + Utils.deleteRecursively(tempDir) + } + } } private object PeriodicRDDCheckpointerSuite 
{ @@ -135,6 +326,15 @@ private object PeriodicRDDCheckpointerSuite { } } + def confirmCheckpointExists(rdd: RDD[_]): Unit = { + val hadoopConf = rdd.sparkContext.hadoopConfiguration + rdd.getCheckpointFile.foreach { checkpointFile => + val path = new Path(checkpointFile) + val fs = path.getFileSystem(hadoopConf) + assert(fs.exists(path), "RDD checkpoint file should not have been removed") + } + } + /** * Check checkpointed status of rdd. * @param gIndex Index of rdd in order inserted into checkpointer (from 1). diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 755c6febc48e..42632a7c1357 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -170,8 +170,42 @@ object Pregel extends Logging { i += 1 } messageCheckpointer.unpersistDataSet() - graphCheckpointer.deleteAllCheckpoints() - messageCheckpointer.deleteAllCheckpoints() + + // Under low system resources, cached RDDs may be evicted from memory and checkpointed + // data may then need to be read back from disk. In that case, none of the checkpoints + // that the resulting graph depends on should be deleted + import PeriodicRDDCheckpointer._ + + val graphDeps = rddDeps(g.vertices) ++ rddDeps(g.edges) + val lastGraphCheckpoint = graphCheckpointer.getLastCheckpoint + val lastGraphCheckpointDeps = lastGraphCheckpoint match { + case Some(value) => rddDeps(value.vertices) ++ rddDeps(value.edges) + case _ => Set.empty + } + + val messagesDeps = rddDeps(messages) + val lastMessagesCheckpoint = messageCheckpointer.getLastCheckpoint + val lastMessagesCheckpointDeps = lastMessagesCheckpoint match { + case Some(value) => rddDeps(value) + case _ => Set.empty + } + + graphCheckpointer.deleteAllCheckpoints { item => + val itemDeps = rddDeps(item.vertices) ++ rddDeps(item.edges) + !lastGraphCheckpoint.exists(_ eq item) && + !haveCommonCheckpoint(itemDeps, lastGraphCheckpointDeps) && + !haveCommonCheckpoint(messagesDeps, itemDeps) && + !haveCommonCheckpoint(messagesDeps, lastGraphCheckpointDeps) + } + + messageCheckpointer.deleteAllCheckpoints { item => + val itemDeps = rddDeps(item) + !lastMessagesCheckpoint.exists(_ eq item) && + !haveCommonCheckpoint(itemDeps, lastMessagesCheckpointDeps) && + !haveCommonCheckpoint(graphDeps, itemDeps) && + !haveCommonCheckpoint(graphDeps, lastMessagesCheckpointDeps) + } + g } // end of apply diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala index 539b66f747cc..a57d2226a3dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala @@ -19,6 +19,7 @@ package org.apache.spark.graphx.util import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.PeriodicCheckpointer @@ -38,12 +39,15 @@ import org.apache.spark.util.PeriodicCheckpointer * - Unpersist graphs from queue until there are at most 3 persisted graphs. * - If using checkpointing and the checkpoint interval has been reached, * - Checkpoint the new graph, and put in a queue of checkpointed graphs. - * - Remove older checkpoints.
+ * - Remove older checkpoints, except for the new one and any checkpoints it depends on. * * WARNINGS: * - This class should NOT be copied (since copies may conflict on which Graphs should be * checkpointed). - * - This class removes checkpoint files once later graphs have been checkpointed. + * - This class removes checkpoint files once later graphs have been checkpointed and no longer + * depend on the graphs whose files would be removed (removing checkpoint files of prior + * graphs that later graphs still depend on may fail with `FileNotFoundException` if those + * later graphs have not yet been materialized). * However, references to the older graphs will still return isCheckpointed = true. * * Example usage: @@ -101,6 +105,20 @@ private[spark] class PeriodicGraphCheckpointer[VD, ED]( override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { - data.getCheckpointFiles + val verticesFiles = PeriodicRDDCheckpointer + .rddDeps(data.vertices) + .flatMap(_.getCheckpointFile) + val edgesFiles = PeriodicRDDCheckpointer + .rddDeps(data.edges) + .flatMap(_.getCheckpointFile) + + verticesFiles ++ edgesFiles + } + + override protected def haveCommonCheckpoint( + newData: Graph[VD, ED], oldData: Graph[VD, ED]): Boolean = { + PeriodicRDDCheckpointer.haveCommonCheckpoint( + Set(newData.vertices, newData.edges), Set(oldData.vertices, oldData.edges)) } + } diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index 66c4747fec26..fdce727cd2a1 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -26,8 +26,7 @@ import org.apache.spark.SparkContext */ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards.
*/ - def withSpark[T](f: SparkContext => T): T = { - val conf = new SparkConf() + def withSpark[T](f: SparkContext => T)(implicit conf: SparkConf = new SparkConf()): T = { GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) try { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala index 90a9ac613ef9..85f78c9abd4b 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/PregelSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.graphx +import org.apache.spark.SparkConf import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils class PregelSuite extends SparkFunSuite with LocalSparkContext { @@ -52,4 +54,55 @@ class PregelSuite extends SparkFunSuite with LocalSparkContext { chain.vertices.mapValues { (vid, attr) => attr + 1 }.collect.toSet) } } + + test("should preserve intermediate checkpoint files with an even number of iterations") { + withEvictedGraph(iterations = 4) { _ => } + } + + test("should preserve intermediate checkpoint files with an odd number of iterations") { + withEvictedGraph(iterations = 5) { _ => } + } + + test("preserve last checkpoint files with an even number of iterations") { + withEvictedGraph(iterations = 4) { graph => + graph.vertices.count() + graph.edges.count() + } + } + + test("preserve last checkpoint files with an odd number of iterations") { + withEvictedGraph(iterations = 5) { graph => + graph.vertices.count() + graph.edges.count() + } + } + + private def withEvictedGraph(iterations: Int)(f: Graph[Long, Int] => Unit): Unit = { + implicit val conf: SparkConf = new SparkConf() + .set("spark.graphx.pregel.checkpointInterval", "2") + // Set a small testing memory so that cached RDDs get evicted from it, forcing + // checkpointed RDDs to be read back from disk + .set("spark.testing.reservedMemory", "128") + .set("spark.testing.memory", "256") + withSpark { sc => + val dir = Utils.createTempDir().getCanonicalFile + try { + sc.setCheckpointDir(dir.toURI.toString) + val edges = (1 to iterations).map(x => (x: VertexId, x + 1: VertexId)) + val graph = Pregel(Graph.fromEdgeTuples(sc.parallelize(edges, 3), 0L), 1L)( + (vid, attr, msg) => if (vid == msg) msg else attr, + et => + if (et.dstId != et.dstAttr && et.srcId < et.dstId) { + Iterator((et.dstId, et.srcAttr + 1)) + } else { + Iterator.empty + }, + (a: Long, b: Long) => math.max(a, b)) + f(graph) + } finally { + Utils.deleteRecursively(dir) + } + } + } + }
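For reference, a minimal usage sketch of the updated checkpointer API, mirroring the tests above. This is illustrative only and not part of the patch: the SparkContext `sc`, the element type, and the example data are assumptions, a checkpoint directory must already be configured, and PeriodicRDDCheckpointer is private[spark], so this only compiles inside Spark's own packages.

import org.apache.spark.SparkContext
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer

// Assumes an existing SparkContext `sc` on which sc.setCheckpointDir(...) has been called.
val checkpointer = new PeriodicRDDCheckpointer[Int](2, sc)

val rdd1 = sc.parallelize(0 until 10)
checkpointer.update(rdd1)   // 1st update: rdd1 is persisted
checkpointer.update(rdd1)   // 2nd update: interval reached, rdd1 is marked for checkpointing
rdd1.count()                // action materializes rdd1's checkpoint files

val rdd2 = rdd1.filter(_ % 2 == 0)
checkpointer.update(rdd2)
checkpointer.update(rdd2)   // rdd2 is marked for checkpointing but not yet materialized

// With this change, checkpoints that the last checkpoint still depends on are retained,
// so materializing rdd2 afterwards does not fail with FileNotFoundException due to
// rdd1's checkpoint files having been deleted.
checkpointer.deleteAllCheckpointsButLast()
rdd2.count()

// Final cleanup once nothing depends on the checkpoint files any more.
checkpointer.deleteAllCheckpoints()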