@@ -17,6 +17,9 @@

package org.apache.spark.rdd.util

import scala.collection.Set
import scala.collection.mutable

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -38,12 +41,15 @@ import org.apache.spark.util.PeriodicCheckpointer
* - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
* - If using checkpointing and the checkpoint interval has been reached,
* - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
* - Remove older checkpoints.
* - Remove older checkpoints, except for the newly created one and all the checkpoints it depends on.
*
* WARNINGS:
* - This class should NOT be copied (since copies may conflict on which RDDs should be
* checkpointed).
* - This class removes checkpoint files once later RDDs have been checkpointed.
* - This class removes checkpoint files once later RDDs have been checkpointed, provided those
* later RDDs do not depend on the files being removed (removing checkpoint files of prior
* RDDs that later RDDs still depend on may fail with `FileNotFoundException` if the later
* RDDs have not yet been materialized).
* However, references to the older RDDs will still return isCheckpointed = true.
*
* Example usage:
@@ -73,8 +79,6 @@ import org.apache.spark.util.PeriodicCheckpointer
*
* @param checkpointInterval RDDs will be checkpointed at this interval
* @tparam T RDD element type
*
* TODO: Move this out of MLlib?
*/
private[spark] class PeriodicRDDCheckpointer[T](
checkpointInterval: Int,
@@ -94,6 +98,34 @@ private[spark] class PeriodicRDDCheckpointer[T](
override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)

override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
data.getCheckpointFile.map(x => x)
PeriodicRDDCheckpointer.rddDeps(data).flatMap(_.getCheckpointFile)
}

override protected def haveCommonCheckpoint(newData: RDD[T], oldData: RDD[T]): Boolean = {
PeriodicRDDCheckpointer.haveCommonCheckpoint(Set(newData), Set(oldData))
}

}

private[spark] object PeriodicRDDCheckpointer {

def rddDeps(rdd: RDD[_]): Set[RDD[_]] = {
val parents = new mutable.HashSet[RDD[_]]
def visit(rdd: RDD[_]): Unit = {
// Recurse into dependencies only the first time an RDD is seen.
if (parents.add(rdd)) {
rdd.dependencies.foreach(dep => visit(dep.rdd))
}
}
visit(rdd)
parents
}

def haveCommonCheckpoint(rdds1: Set[_ <: RDD[_]], rdds2: Set[_ <: RDD[_]]): Boolean = {
val deps1 = rdds1.foldLeft(new mutable.HashSet[RDD[_]]()) { (set, rdd) =>
set ++= rddDeps(rdd)
}
val deps2 = rdds2.foldLeft(new mutable.HashSet[RDD[_]]()) { (set, rdd) =>
set ++= rddDeps(rdd)
}
deps1.intersect(deps2).exists(_.isCheckpointed)
}
}
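The new `rddDeps` and `haveCommonCheckpoint` helpers above walk an RDD's full lineage and check whether two groups of RDDs share a checkpointed ancestor. A minimal sketch of the intended behaviour on a small lineage (not part of this patch; it assumes a live `SparkContext` with a checkpoint directory set, and is placed in `org.apache.spark.rdd.util` only so the `private[spark]` object is visible):

```scala
// Sketch only: placed in this package so the private[spark] helpers are visible.
package org.apache.spark.rdd.util

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object CommonCheckpointSketch {
  def demo(sc: SparkContext): Unit = {
    val base: RDD[Int] = sc.parallelize(0 until 100)
    base.checkpoint()
    base.count() // materialize so the checkpoint files are actually written

    // derived keeps base (and therefore base's checkpoint) in its lineage
    val derived = base.map(_ * 2)

    // rddDeps returns the RDD itself plus every ancestor in its dependency DAG
    val deps = PeriodicRDDCheckpointer.rddDeps(derived)
    assert(deps.contains(base))

    // base is checkpointed and sits in derived's lineage, so the two share a
    // checkpoint; deleting base's files while derived is unmaterialized is unsafe
    assert(PeriodicRDDCheckpointer.haveCommonCheckpoint(Set(base), Set(derived)))
  }
}
```

The second assertion is exactly the situation the checkpointer must now respect: `base`'s files may not be deleted while `derived` has not yet been materialized.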
@@ -97,15 +97,7 @@ private[spark] abstract class PeriodicCheckpointer[T](
checkpoint(newData)
checkpointQueue.enqueue(newData)
// Remove checkpoints before the latest one.
var canDelete = true
while (checkpointQueue.size > 1 && canDelete) {
// Delete the oldest checkpoint only if the next checkpoint exists.
if (isCheckpointed(checkpointQueue.head)) {
removeCheckpointFile()
} else {
canDelete = false
}
}
deleteAllCheckpointsButLast()
}
}

@@ -127,6 +119,11 @@ private[spark] abstract class PeriodicCheckpointer[T](
/** Get list of checkpoint files for this given Dataset */
protected def getCheckpointFiles(data: T): Iterable[String]

/**
* Checks whether the two datasets depend on the same checkpointed data.
*/
protected def haveCommonCheckpoint(newData: T, oldData: T): Boolean

/**
* Call this to unpersist the Dataset.
*/
@@ -137,22 +134,38 @@ private[spark] abstract class PeriodicCheckpointer[T](
}
}

/**
* Gets the last checkpoint if one is available.
*/
def getLastCheckpoint: Option[T] = {
checkpointQueue.lastOption
}

/**
* Call this at the end to delete any remaining checkpoint files.
*/
def deleteAllCheckpoints(): Unit = {
while (checkpointQueue.nonEmpty) {
removeCheckpointFile()
}
deleteAllCheckpoints(_ => true)
}

/**
* Deletes all the checkpoints which match the given predicate.
*/
def deleteAllCheckpoints(f: T => Boolean): Unit = {
val checkpoints = checkpointQueue.dequeueAll(f)
checkpoints.foreach(removeCheckpointFile)
}

/**
* Call this at the end to delete any remaining checkpoint files, except for the last checkpoint.
* Note that there may not be any checkpoints at all.
* Note that there may not be any checkpoints at all; if there is more than one checkpoint,
* any checkpoints that the last one depends on are retained as well.
*/
def deleteAllCheckpointsButLast(): Unit = {
while (checkpointQueue.size > 1) {
removeCheckpointFile()
getLastCheckpoint.foreach { last =>
deleteAllCheckpoints { item =>
item != last && !haveCommonCheckpoint(last, item)
}
}
}

@@ -171,7 +184,14 @@ private[spark] abstract class PeriodicCheckpointer[T](
private def removeCheckpointFile(): Unit = {
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark, we manually delete it.
getCheckpointFiles(old).foreach(
removeCheckpointFile(old)
}

/**
* Removes checkpoint files of the provided Dataset.
*/
private def removeCheckpointFile(item: T): Unit = {
getCheckpointFiles(item).foreach(
PeriodicCheckpointer.removeCheckpointFile(_, sc.hadoopConfiguration))
}
}
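Together with the `PeriodicRDDCheckpointer` changes above, cleanup now only deletes a checkpoint once no later checkpoint still depends on it. A minimal usage sketch of the scenario this protects (again not part of the patch; it assumes a `SparkContext` with a checkpoint directory configured, lives in `org.apache.spark.rdd.util` to reach the `private[spark]` class, and uses a hypothetical checkpoint interval of 1 so every update checkpoints):

```scala
// Sketch only: placed in this package so the private[spark] class is visible.
package org.apache.spark.rdd.util

import org.apache.spark.SparkContext

object DependentCheckpointSketch {
  def demo(sc: SparkContext): Unit = {
    // Hypothetical interval of 1: every update() checkpoints the RDD it is given.
    val checkpointer = new PeriodicRDDCheckpointer[Int](1, sc)

    val rdd1 = sc.parallelize(0 until 10)
    checkpointer.update(rdd1) // rdd1 is checkpointed
    rdd1.count()              // rdd1's checkpoint files are written

    val rdd2 = rdd1.filter(_ % 2 == 0) // rdd2's lineage still reads rdd1's checkpoint
    checkpointer.update(rdd2)
    // rdd1's files are retained because the latest checkpoint (rdd2) still depends on them.
    rdd2.count() // safe to materialize rdd2 now

    checkpointer.deleteAllCheckpoints() // explicit cleanup once nothing needs the files
  }
}
```

With the previous behaviour, the second `update` would have deleted `rdd1`'s checkpoint files immediately, so `rdd2.count()` could fail with `FileNotFoundException`; the last test in the suite below exercises the same interleaving.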
@@ -15,15 +15,15 @@
* limitations under the License.
*/

package org.apache.spark.utils
package org.apache.spark.util

import org.apache.hadoop.fs.Path

import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.MapPartitionsRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext {
@@ -79,6 +79,197 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext

Utils.deleteRecursively(tempDir)
}

test("Getting RDD dependencies should return RDD itself") {
val rdd = sc.emptyRDD[Int]
assert(PeriodicRDDCheckpointer.rddDeps(rdd) == Set(rdd))
}

test("Getting RDD dependencies should return all the DAG RDDs") {
val data = 0 until 10
val initialRdd = sc.parallelize(data)
val targetRdd = data.foldLeft(initialRdd) { (rdd, num) =>
rdd.filter(_ == num)
}

val deps = PeriodicRDDCheckpointer.rddDeps(targetRdd)
assert(deps.size == data.size + 1)
assert(deps.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) == data.size)
}

test("Common checkpoint should be found when RDDs are related") {
val tempDir = Utils.createTempDir()
try {
sc.setCheckpointDir(tempDir.toURI.toString)

val rdd1 = createRDD(sc)
rdd1.checkpoint()
rdd1.count()

val rdd2 = rdd1.filter(_ => true)

assert(PeriodicRDDCheckpointer.haveCommonCheckpoint(Set(rdd1), Set(rdd2)))
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("Common checkpoint should not be found when RDDs are unrelated") {
val tempDir = Utils.createTempDir()
try {
sc.setCheckpointDir(tempDir.toURI.toString)

val rdd1 = createRDD(sc)
rdd1.checkpoint()
rdd1.count()

val rdd2 = createRDD(sc)
rdd2.checkpoint()
rdd2.count()

assert(!PeriodicRDDCheckpointer.haveCommonCheckpoint(Set(rdd1), Set(rdd2)))
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("Checkpointing of dependent RDD should not fail when materializing it") {
val tempDir = Utils.createTempDir()
try {
val checkpointInterval = 2
sc.setCheckpointDir(tempDir.toURI.toString)

val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc)

val rdd1 = createRDD(sc)
checkpointer.update(rdd1)
checkpointer.update(rdd1)
rdd1.count()

val rdd2 = rdd1.filter(_ => true)
checkpointer.update(rdd2)
checkpointer.update(rdd2)
rdd2.count()

checkpointer.deleteAllCheckpoints()
Seq(rdd1, rdd2).foreach { rdd =>
confirmCheckpointRemoved(rdd)
}
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("deleteAllCheckpointsButLast should retain last checkpoint only when RDDs are unrelated") {
val tempDir = Utils.createTempDir()
try {
val checkpointInterval = 2
sc.setCheckpointDir(tempDir.toURI.toString)

val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc)

val rdd1 = createRDD(sc)
checkpointer.update(rdd1)
checkpointer.update(rdd1)
rdd1.count()

val rdd2 = createRDD(sc)
checkpointer.update(rdd2)
checkpointer.update(rdd2)

checkpointer.deleteAllCheckpointsButLast()
Seq(rdd1).foreach(confirmCheckpointRemoved)
Seq(rdd2).foreach(confirmCheckpointExists)
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("deleteAllCheckpointsButLast should retain last checkpoint and dependent checkpoints " +
"when RDDs are related") {
val tempDir = Utils.createTempDir()
try {
val checkpointInterval = 2
sc.setCheckpointDir(tempDir.toURI.toString)

val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc)

val rdd1 = createRDD(sc)
checkpointer.update(rdd1)
checkpointer.update(rdd1)
rdd1.count()

val rdd2 = rdd1.filter(_ => true)
checkpointer.update(rdd2)
checkpointer.update(rdd2)

checkpointer.deleteAllCheckpointsButLast()
Seq(rdd1, rdd2).foreach(confirmCheckpointExists)
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("deleteAllCheckpoints should remove all the checkpoints") {
val tempDir = Utils.createTempDir()
try {
val checkpointInterval = 2
sc.setCheckpointDir(tempDir.toURI.toString)

val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, sc)

val rdd1 = createRDD(sc)
checkpointer.update(rdd1)
checkpointer.update(rdd1)
rdd1.count()

val rdd2 = rdd1.filter(_ => true)
checkpointer.update(rdd2)
checkpointer.update(rdd2)

checkpointer.deleteAllCheckpoints()
Seq(rdd1, rdd2).foreach(confirmCheckpointRemoved)
} finally {
Utils.deleteRecursively(tempDir)
}
}

test("deleteAllCheckpoints should remove all the checkpoint files when " +
"there is just parent checkpointed RDD") {
val tempDir = Utils.createTempDir()
try {
val checkpointInterval = 2
sc.setCheckpointDir(tempDir.toURI.toString)

val checkpointer = new PeriodicRDDCheckpointer[(Int, Int)](checkpointInterval, sc)
val rdd1 = sc.makeRDD((0 until 10).map(i => i -> i)).setName("rdd1")

// rdd1 is not materialized yet, checkpointer(update=1, checkpointInterval=2)
checkpointer.update(rdd1)
// rdd2 depends on rdd1
val rdd2 = rdd1.filter(_ => true).setName("rdd2")

// rdd1 is materialized, checkpointer(update=2, checkpointInterval=2)
checkpointer.update(rdd1)
// rdd3 depends on rdd1
val rdd3 = rdd1.filter(_ => true).setName("rdd3")

// rdd3 is not materialized yet, checkpointer(update=3, checkpointInterval=2)
checkpointer.update(rdd3)
// rdd3 is materialized, rdd1 is removed, checkpointer(update=4, checkpointInterval=2)
checkpointer.update(rdd3)

// should not fail
rdd2.count()

checkpointer.deleteAllCheckpoints()
Seq(rdd1, rdd2, rdd3).foreach { rdd =>
confirmCheckpointRemoved(rdd)
}
} finally {
Utils.deleteRecursively(tempDir)
}
}
}

private object PeriodicRDDCheckpointerSuite {
Expand Down Expand Up @@ -135,6 +326,15 @@ private object PeriodicRDDCheckpointerSuite {
}
}

def confirmCheckpointExists(rdd: RDD[_]): Unit = {
val hadoopConf = rdd.sparkContext.hadoopConfiguration
rdd.getCheckpointFile.foreach { checkpointFile =>
val path = new Path(checkpointFile)
val fs = path.getFileSystem(hadoopConf)
assert(fs.exists(path), "RDD checkpoint file should not have been removed")
}
}

/**
* Check checkpointed status of rdd.
* @param gIndex Index of rdd in order inserted into checkpointer (from 1).