[jvm-packages] cleaning checkpoint file after a successful training (#4754)

* cleaning checkpoint file after a successful training

* address comments
CodingCat authored Aug 14, 2019
1 parent ef9af33 commit 7b5cbcc
Showing 5 changed files with 88 additions and 62 deletions.
CheckpointManager.scala
@@ -53,6 +53,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
     }
   }

+  def cleanPath(): Unit = {
+    if (checkpointPath != "") {
+      FileSystem.get(sc.hadoopConfiguration).delete(new Path(checkpointPath), true)
+    }
+  }
+
   /**
    * Load existing checkpoint with the highest version as a Booster object
    *
@@ -127,7 +133,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)

 object CheckpointManager {

-  private[spark] def extractParams(params: Map[String, Any]): (String, Int) = {
+  case class CheckpointParam(
+      checkpointPath: String,
+      checkpointInterval: Int,
+      skipCleanCheckpoint: Boolean)
+
+  private[spark] def extractParams(params: Map[String, Any]): CheckpointParam = {
     val checkpointPath: String = params.get("checkpoint_path") match {
       case None => ""
       case Some(path: String) => path
@@ -141,6 +152,13 @@
       case _ => throw new IllegalArgumentException("parameter \"checkpoint_interval\" must be" +
         " an instance of Int.")
     }
-    (checkpointPath, checkpointInterval)
+
+    val skipCheckpointFile: Boolean = params.get("skip_clean_checkpoint") match {
+      case None => false
+      case Some(skipCleanCheckpoint: Boolean) => skipCleanCheckpoint
+      case _ => throw new IllegalArgumentException("parameter \"skip_clean_checkpoint\" must be" +
+        " an instance of Boolean")
+    }
+    CheckpointParam(checkpointPath, checkpointInterval, skipCheckpointFile)
   }
 }
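For illustration, the new parsing can be exercised like this (a minimal sketch, not part of the commit; it assumes code living inside the ml.dmlc.xgboost4j.scala.spark package, since extractParams is private[spark], and the checkpoint path is hypothetical):

    // Sketch: extractParams now returns a CheckpointParam instead of a (String, Int) tuple.
    val cp = CheckpointManager.extractParams(Map(
      "checkpoint_path" -> "hdfs:///tmp/xgb-checkpoints", // hypothetical path
      "checkpoint_interval" -> 2))
    // cp.checkpointPath == "hdfs:///tmp/xgb-checkpoints"
    // cp.checkpointInterval == 2
    // cp.skipCleanCheckpoint == false (the key is absent, so the None branch applies)
    // A non-Boolean "skip_clean_checkpoint" value throws IllegalArgumentException.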
XGBoost.scala
@@ -331,9 +331,11 @@ object XGBoost extends Serializable {
       case _ => throw new IllegalArgumentException("parameter \"timeout_request_workers\" must be" +
         " an instance of Long.")
     }
-    val (checkpointPath, checkpointInterval) = CheckpointManager.extractParams(params)
+    val checkpointParam =
+      CheckpointManager.extractParams(params)
     (nWorkers, round, useExternalMemory, obj, eval, missing, trackerConf, timeoutRequestWorkers,
-      checkpointPath, checkpointInterval)
+      checkpointParam.checkpointPath, checkpointParam.checkpointInterval,
+      checkpointParam.skipCleanCheckpoint)
   }

   private def trainForNonRanking(
@@ -343,7 +345,7 @@
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPoints => {
@@ -373,7 +375,7 @@
       checkpointRound: Int,
       prevBooster: Booster,
       evalSetsMap: Map[String, RDD[XGBLabeledPoint]]): RDD[(Booster, Map[String, Array[Float]])] = {
-    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _) =
+    val (nWorkers, _, useExternalMemory, obj, eval, missing, _, _, _, _, _) =
       parameterFetchAndValidation(params, trainingData.sparkContext)
     if (evalSetsMap.isEmpty) {
       trainingData.mapPartitions(labeledPointGroups => {
@@ -427,7 +429,8 @@
       (Booster, Map[String, Array[Float]]) = {
     logger.info(s"Running XGBoost ${spark.VERSION} with parameters:\n${params.mkString("\n")}")
     val (nWorkers, round, _, _, _, _, trackerConf, timeoutRequestWorkers,
-      checkpointPath, checkpointInterval) = parameterFetchAndValidation(params,
+      checkpointPath, checkpointInterval, skipCleanCheckpoint) =
+      parameterFetchAndValidation(params,
       trainingData.sparkContext)
     val sc = trainingData.sparkContext
     val checkpointManager = new CheckpointManager(sc, checkpointPath)
@@ -437,7 +440,7 @@
     var prevBooster = checkpointManager.loadCheckpointAsBooster
     try {
       // Train for every ${savingRound} rounds and save the partially completed booster
-      checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
+      val producedBooster = checkpointManager.getCheckpointRounds(checkpointInterval, round).map {
         checkpointRound: Int =>
           val tracker = startTracker(nWorkers, trackerConf)
           try {
@@ -473,6 +476,11 @@
           tracker.stop()
         }
       }.last
+      // we should delete the checkpoint directory after a successful training
+      if (!skipCleanCheckpoint) {
+        checkpointManager.cleanPath()
+      }
+      producedBooster
     } catch {
       case t: Throwable =>
         // if the job was aborted due to an exception
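Reduced to its essentials, the change to the training loop binds the last booster of the checkpointed spans to a name and deletes the checkpoint directory only after every span has completed; an exception propagates before the cleanup runs, so checkpoints survive failed jobs and a retry can resume from the latest one. A simplified sketch of that control flow (checkpointRounds and trainSpan are stand-ins, not the actual code):

    // Sketch of the success-path cleanup; trainSpan stands in for the
    // per-span training closure in the real train() method.
    def trainWithCleanup(
        checkpointRounds: Seq[Int],
        skipCleanCheckpoint: Boolean,
        checkpointManager: CheckpointManager,
        trainSpan: Int => Booster): Booster = {
      // train each span in order; the last element is the fully trained booster
      val producedBooster = checkpointRounds.map(trainSpan).last
      // reached only if every span succeeded
      if (!skipCleanCheckpoint) {
        checkpointManager.cleanPath() // recursive delete of the checkpoint directory
      }
      producedBooster
    }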
LearningTaskParams.scala
@@ -80,6 +80,13 @@ private[spark] trait LearningTaskParams extends Params {
   final val cacheTrainingSet = new BooleanParam(this, "cacheTrainingSet",
     "whether caching training data")

+  /**
+   * Whether to skip cleaning the checkpoint after a successful training; the checkpoint is
+   * always cleaned by default. This parameter exists mainly for testing.
+   */
+  final val skipCleanCheckpoint = new BooleanParam(this, "skipCleanCheckpoint",
+    "whether to skip cleaning the checkpoint data after a successful training")
+
   /**
    * If non-zero, the training will be stopped after a specified number
    * of consecutive increases in any evaluation metric.
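From user code the flag travels through the estimator's parameter map under the snake_case key skip_clean_checkpoint, as the test suite below exercises; a usage sketch with hypothetical values:

    val classifier = new XGBoostClassifier(Map(
      "objective" -> "binary:logistic",
      "num_round" -> 8,
      "num_workers" -> 2,                          // hypothetical worker count
      "checkpoint_path" -> "/tmp/xgb-checkpoints", // hypothetical path
      "checkpoint_interval" -> 2,
      "skip_clean_checkpoint" -> true))            // keep checkpoint files after training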
CheckpointManagerSuite.scala
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.io.File

+import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
 import org.scalatest.FunSuite
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -67,4 +68,50 @@ class CheckpointManagerSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     assertResult(Seq(4, 6, 7))(manager.getCheckpointRounds(2, 7))
   }

+  private def trainingWithCheckpoint(cacheData: Boolean, skipCleanCheckpoint: Boolean): Unit = {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+
+    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
+    val cacheDataMap = if (cacheData) Map("cacheTrainingSet" -> true) else Map()
+    val skipCleanCheckpointMap =
+      if (skipCleanCheckpoint) Map("skip_clean_checkpoint" -> true) else Map()
+    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
+      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
+      "checkpoint_interval" -> 2, "num_workers" -> numWorkers) ++ cacheDataMap ++
+      skipCleanCheckpointMap
+
+    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
+    def error(model: Booster): Float = eval.eval(
+      model.predict(testDM, outPutMargin = true), testDM)
+
+    if (skipCleanCheckpoint) {
+      // Check only one model is kept after training
+      val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
+      assert(files.length == 1)
+      assert(files.head.getPath.getName == "8.model")
+      val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
+      // Train next model based on prev model
+      val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
+      assert(error(tmpModel) > error(prevModel._booster))
+      assert(error(prevModel._booster) > error(nextModel._booster))
+      assert(error(nextModel._booster) < 0.1)
+    } else {
+      assert(!FileSystem.get(sc.hadoopConfiguration).exists(new Path(tmpPath)))
+    }
+  }
+
+  test("training with checkpoint boosters") {
+    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = true)
+  }
+
+  test("training with checkpoint boosters with cached training dataset") {
+    trainingWithCheckpoint(cacheData = true, skipCleanCheckpoint = true)
+  }
+
+  test("the checkpoint file should be cleaned after a successful training") {
+    trainingWithCheckpoint(cacheData = false, skipCleanCheckpoint = false)
+  }
 }
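The else branch of trainingWithCheckpoint covers the new default behavior: after a successful run the checkpoint directory no longer exists. The same post-training check can be made from any code with a SparkContext in scope (a sketch mirroring the suite's assertion; sc and checkpointDir are assumed to be defined):

    import org.apache.hadoop.fs.{FileSystem, Path}

    // After a successful training with default settings, the whole
    // checkpoint directory should have been deleted.
    val fs = FileSystem.get(sc.hadoopConfiguration)
    assert(!fs.exists(new Path(checkpointDir)))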
XGBoostGeneralSuite.scala
@@ -179,60 +179,6 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     assert(x < 0.1)
   }

-  test("training with checkpoint boosters") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers)
-
-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
-
-    // Check only one model is kept after training
-    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
-    assert(files.length == 1)
-    assert(files.head.getPath.getName == "8.model")
-    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
-
-    // Train next model based on prev model
-    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-    assert(error(tmpModel) > error(prevModel._booster))
-    assert(error(prevModel._booster) > error(nextModel._booster))
-    assert(error(nextModel._booster) < 0.1)
-  }
-
-  test("training with checkpoint boosters with cached training dataset") {
-    val eval = new EvalError()
-    val training = buildDataFrame(Classification.train)
-    val testDM = new DMatrix(Classification.test.iterator)
-
-    val tmpPath = createTmpFolder("model1").toAbsolutePath.toString
-    val paramMap = Map("eta" -> "1", "max_depth" -> 2,
-      "objective" -> "binary:logistic", "checkpoint_path" -> tmpPath,
-      "checkpoint_interval" -> 2, "num_workers" -> numWorkers, "cacheTrainingSet" -> true)
-
-    val prevModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 5)).fit(training)
-    def error(model: Booster): Float = eval.eval(
-      model.predict(testDM, outPutMargin = true), testDM)
-
-    // Check only one model is kept after training
-    val files = FileSystem.get(sc.hadoopConfiguration).listStatus(new Path(tmpPath))
-    assert(files.length == 1)
-    assert(files.head.getPath.getName == "8.model")
-    val tmpModel = SXGBoost.loadModel(s"$tmpPath/8.model")
-
-    // Train next model based on prev model
-    val nextModel = new XGBoostClassifier(paramMap ++ Seq("num_round" -> 8)).fit(training)
-    assert(error(tmpModel) > error(prevModel._booster))
-    assert(error(prevModel._booster) > error(nextModel._booster))
-    assert(error(nextModel._booster) < 0.1)
-  }
-
   test("repartitionForTrainingGroup with group data") {
     // test different splits to cover the corner cases.
     for (split <- 1 to 20) {
