[jvm-packages] Implemented early stopping #2710

Merged · 7 commits · Sep 29, 2017
Changes from all commits
@@ -85,7 +85,7 @@ object BasicWalkThrough {
     val watches2 = new mutable.HashMap[String, DMatrix]
     watches2 += "train" -> trainMax2
     watches2 += "test" -> testMax2
-    val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap, null, null)
+    val booster3 = XGBoost.train(trainMax2, params.toMap, round, watches2.toMap)
     val predicts3 = booster3.predict(testMax2)
     println(checkPredicts(predicts, predicts3))
   }
@@ -41,6 +41,6 @@ object CrossValidation {
     val metrics: Array[String] = null

     val evalHist: Array[String] =
-      XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics, null, null)
+      XGBoost.crossValidation(trainMat, params.toMap, round, nfold, metrics)
   }
 }
@@ -151,7 +151,8 @@ object CustomObjective {
     val round = 2
     // train a model
     val booster = XGBoost.train(trainMat, params.toMap, round, watches.toMap)
-    XGBoost.train(trainMat, params.toMap, round, watches.toMap, new LogRegObj, new EvalError)
+    XGBoost.train(trainMat, params.toMap, round, watches.toMap,
+      obj = new LogRegObj, eval = new EvalError)
   }

 }
@@ -54,6 +54,6 @@ object ExternalMemory {
     testMat.setBaseMargin(testPred)

     System.out.println("result of running from initial prediction")
-    val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
+    val booster2 = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
   }
 }
@@ -52,7 +52,7 @@ object GeneralizedLinearModel {
     watches += "test" -> testMat

     val round = 4
-    val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap, null, null)
+    val booster = XGBoost.train(trainMat, params.toMap, 1, watches.toMap)
     val predicts = booster.predict(testMat)
     val eval = new CustomEval
     println(s"error=${eval.eval(predicts, testMat)}")
@@ -55,7 +55,10 @@ object XGBoost {
     val trainMat = new DMatrix(dataIter, null)
     val watches = List("train" -> trainMat).toMap
     val round = 2
-    val booster = XGBoostScala.train(trainMat, paramMap, round, watches, null, null)
+    val numEarlyStoppingRounds = paramMap.get("numEarlyStoppingRounds")
+      .map(_.toString.toInt).getOrElse(0)
+    val booster = XGBoostScala.train(trainMat, paramMap, round, watches,
+      earlyStoppingRound = numEarlyStoppingRounds)
     Rabit.shutdown()
     collector.collect(new XGBoostModel(booster))
   }
@@ -17,6 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark

 import scala.collection.mutable
+import scala.util.Random

 import ml.dmlc.xgboost4j.java.{IRabitTracker, Rabit, XGBoostError, RabitTracker => PyRabitTracker}
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
@@ -25,9 +26,9 @@
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}

 import org.apache.commons.logging.LogFactory
 import org.apache.hadoop.fs.{FSDataInputStream, Path}
-import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.ml.feature.{LabeledPoint => MLLabeledPoint}
 import org.apache.spark.{SparkContext, TaskContext}

 object TrackerConf {
@@ -94,7 +95,7 @@ object XGBoost extends Serializable {
   }

   private[spark] def buildDistributedBoosters(
-      trainingSet: RDD[XGBLabeledPoint],
+      data: RDD[XGBLabeledPoint],
       params: Map[String, Any],
       rabitEnv: java.util.Map[String, String],
       numWorkers: Int,
@@ -103,19 +104,19 @@
       eval: EvalTrait,
       useExternalMemory: Boolean,
       missing: Float): RDD[Booster] = {
-    val partitionedTrainingSet = if (trainingSet.getNumPartitions != numWorkers) {
+    val partitionedData = if (data.getNumPartitions != numWorkers) {
       logger.info(s"repartitioning training set to $numWorkers partitions")
-      trainingSet.repartition(numWorkers)
+      data.repartition(numWorkers)
     } else {
-      trainingSet
+      data
     }
-    val partitionedBaseMargin = partitionedTrainingSet.map(_.baseMargin)
-    val appName = partitionedTrainingSet.context.appName
+    val partitionedBaseMargin = partitionedData.map(_.baseMargin)
+    val appName = partitionedData.context.appName
     // to workaround the empty partitions in training dataset,
     // this might not be the best efficient implementation, see
     // (https://github.com/dmlc/xgboost/issues/1277)
-    partitionedTrainingSet.zipPartitions(partitionedBaseMargin) { (trainingPoints, baseMargins) =>
-      if (trainingPoints.isEmpty) {
+    partitionedData.zipPartitions(partitionedBaseMargin) { (labeledPoints, baseMargins) =>
+      if (labeledPoints.isEmpty) {
         throw new XGBoostError(
           s"detected an empty partition in the training data, partition ID:" +
           s" ${TaskContext.getPartitionId()}")
@@ -128,21 +129,20 @@
       }
       rabitEnv.put("DMLC_TASK_ID", TaskContext.getPartitionId().toString)
       Rabit.init(rabitEnv)
-      val trainingMatrix = new DMatrix(
-        fromDenseToSparseLabeledPoints(trainingPoints, missing), cacheFileName)
+      val watches = Watches(params,
+        fromDenseToSparseLabeledPoints(labeledPoints, missing),
+        fromBaseMarginsToArray(baseMargins), cacheFileName)

       try {
-        // TODO: use group attribute from the points.
-        if (params.contains("groupData") && params("groupData") != null) {
-          trainingMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
-            TaskContext.getPartitionId()).toArray)
-        }
-        fromBaseMarginsToArray(baseMargins).foreach(trainingMatrix.setBaseMargin)
-        val booster = SXGBoost.train(trainingMatrix, params, round,
-          watches = Map("train" -> trainingMatrix), obj, eval)
+        val numEarlyStoppingRounds = params.get("numEarlyStoppingRounds")
+          .map(_.toString.toInt).getOrElse(0)
+        val booster = SXGBoost.train(watches.train, params, round,
+          watches = watches.toMap, obj = obj, eval = eval,
+          earlyStoppingRound = numEarlyStoppingRounds)
         Iterator(booster)
       } finally {
         Rabit.shutdown()
-        trainingMatrix.delete()
+        watches.delete()
       }
     }.cache()
   }
@@ -417,3 +417,46 @@ object XGBoost extends Serializable {
     }
   }
 }
+
+private class Watches private(val train: DMatrix, val test: DMatrix) {
Member: what if the user wants to monitor more than 2 datasets?

Contributor Author: This is somewhat tricky to do using spark.ml APIs, so I planned to delay the implementation until there's a request. Otherwise, we could change Watches to wrap a Map and expose train/test as properties, as sketched below.
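For concreteness, a minimal sketch of that Map-backed alternative. It is hypothetical and not part of this PR; MapWatches is an invented name, and the train/test accessors assume the caller supplied entries under those keys.

import ml.dmlc.xgboost4j.scala.DMatrix

// Hypothetical Map-backed variant: any number of named evaluation sets,
// with train/test kept as convenience accessors for the existing call sites.
private class MapWatches(val toMap: Map[String, DMatrix]) {
  def train: DMatrix = toMap("train") // assumes a "train" entry was supplied
  def test: DMatrix = toMap("test")   // assumes a "test" entry was supplied
  def size: Int = toMap.size
  def delete(): Unit = toMap.values.foreach(_.delete())
}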

+  def toMap: Map[String, DMatrix] = Map("train" -> train, "test" -> test)
+    .filter { case (_, matrix) => matrix.rowNum > 0 }
+
+  def size: Int = toMap.size
+
+  def delete(): Unit = {
+    toMap.values.foreach(_.delete())
+  }
+
+  override def toString: String = toMap.toString
+}
+
+private object Watches {
+  def apply(
+      params: Map[String, Any],
+      labeledPoints: Iterator[XGBLabeledPoint],
+      baseMarginsOpt: Option[Array[Float]],
+      cacheFileName: String): Watches = {
+    val trainTestRatio = params.get("trainTestRatio").map(_.toString.toDouble).getOrElse(1.0)
+    val seed = params.get("seed").map(_.toString.toLong).getOrElse(System.nanoTime())
+    val r = new Random(seed)
+    // In the worst-case this would store [[trainTestRatio]] of points
+    // buffered in memory.
+    val (trainPoints, testPoints) = labeledPoints.partition(_ => r.nextDouble() <= trainTestRatio)
+    val trainMatrix = new DMatrix(trainPoints, cacheFileName)
+    val testMatrix = new DMatrix(testPoints, cacheFileName)
+    r.setSeed(seed)
+    for (baseMargins <- baseMarginsOpt) {
+      val (trainMargin, testMargin) = baseMargins.partition(_ => r.nextDouble() <= trainTestRatio)
+      trainMatrix.setBaseMargin(trainMargin)
+      testMatrix.setBaseMargin(testMargin)
+    }
+
+    // TODO: use group attribute from the points.
+    if (params.contains("groupData") && params("groupData") != null) {
+      trainMatrix.setGroup(params("groupData").asInstanceOf[Seq[Seq[Int]]](
+        TaskContext.getPartitionId()).toArray)
+    }
+    new Watches(train = trainMatrix, test = testMatrix)
+  }
+}
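A note on the r.setSeed(seed) call above: resetting the Random to the original seed replays exactly the same sequence of nextDouble() draws used to split the points, so the base margins land in the same train/test buckets and stay aligned row-for-row. A self-contained sketch of that property, with illustrative names and data:

import scala.util.Random

object SeedReplayCheck extends App {
  val seed = 42L
  val ratio = 0.8
  val r = new Random(seed)
  // First pass: split the "points" (stand-ins for labeled points).
  val pointsSplit = (1 to 10).partition(_ => r.nextDouble() <= ratio)
  // Replay: the same seed reproduces the identical draw sequence,
  // so the "margins" are partitioned into the same positions.
  r.setSeed(seed)
  val marginsSplit = (1 to 10).partition(_ => r.nextDouble() <= ratio)
  assert(pointsSplit == marginsSplit)
  println(s"both splits: $pointsSplit")
}

This only holds because both partition calls consume exactly one draw per element over collections of equal length.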
@@ -17,7 +17,7 @@
 package ml.dmlc.xgboost4j.scala.spark.params

 import ml.dmlc.xgboost4j.scala.spark.TrackerConf
-import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}

 import org.apache.spark.ml.param._

 trait GeneralParams extends Params {
@@ -99,9 +99,12 @@
    */
   val trackerConf = new TrackerConfParam(this, "tracker_conf", "Rabit tracker configurations")

+  /** Random seed for the C++ part of XGBoost and train/test splitting. */
+  val seed = new LongParam(this, "seed", "random seed")
+
   setDefault(round -> 1, nWorkers -> 1, numThreadPerTask -> 1,
     useExternalMemory -> false, silent -> 0,
     customObj -> null, customEval -> null, missing -> Float.NaN,
-    trackerConf -> TrackerConf()
+    trackerConf -> TrackerConf(), seed -> 0
   )
 }
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark.params

 import scala.collection.immutable.HashSet

-import org.apache.spark.ml.param.{DoubleParam, IntParam, Param, Params}
+import org.apache.spark.ml.param._

 trait LearningTaskParams extends Params {

@@ -70,8 +70,25 @@ trait LearningTaskParams extends Params {
    */
   val weightCol = new Param[String](this, "weightCol", "weight column name")

+  /**
+   * Fraction of training points to use for testing.
+   */
+  val trainTestRatio = new DoubleParam(this, "trainTestRatio",
+    "fraction of training points to use for testing",
+    ParamValidators.inRange(0, 1))
+
+  /**
+   * If non-zero, the training will be stopped after a specified number
+   * of consecutive increases in any evaluation metric.
+   */
+  val numEarlyStoppingRounds = new IntParam(this, "numEarlyStoppingRounds",
+    "number of rounds of decreasing eval metric to tolerate before " +
+    "stopping the training",
+    (value: Int) => value == 0 || value > 1)
+
   setDefault(objective -> "reg:linear", baseScore -> 0.5, numClasses -> 2, groupData -> null,
-    baseMarginCol -> "baseMargin", weightCol -> "weight")
+    baseMarginCol -> "baseMargin", weightCol -> "weight", trainTestRatio -> 1.0,
+    numEarlyStoppingRounds -> 0)
 }

 private[spark] object LearningTaskParams {
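A minimal usage sketch for the two new parameters plus seed, modeled on the string-keyed paramMap style used in the test suite below; the values are illustrative and the commented-out training call assumes the trainWithDataFrame signature exercised later in this PR.

// Illustrative parameter map wiring the new knobs together.
val paramMap = Map(
  "eta" -> "1", "max_depth" -> "6", "silent" -> "1",
  "objective" -> "binary:logistic",
  "trainTestRatio" -> 0.8,       // hold out ~20% of each partition for eval
  "numEarlyStoppingRounds" -> 5, // tolerate 5 non-improving rounds, then stop
  "seed" -> 42L                  // fixed seed => reproducible train/test split
)
// XGBoost.trainWithDataFrame(trainingDf, paramMap, round = 50, nWorkers = 2)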
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
+
 import org.apache.spark.ml.linalg.DenseVector
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql._
@@ -201,7 +202,8 @@ class XGBoostDFSuite extends FunSuite with PerTest {
     val trainingDfWithMargin = trainingDf.withColumn("margin", functions.rand())
     val testRDD = sc.parallelize(Classification.test.map(_.features))
     val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
-      "objective" -> "binary:logistic", "baseMarginCol" -> "margin")
+      "objective" -> "binary:logistic", "baseMarginCol" -> "margin",
+      "trainTestRatio" -> 0.5)

     def trainPredict(df: Dataset[_]): Array[Float] = {
       XGBoost.trainWithDataFrame(df, paramMap, round = 1, nWorkers = numWorkers)
@@ -201,13 +201,20 @@ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, float
    */
   public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
       throws XGBoostError {
+    // Hopefully, a tiny redundant allocation wouldn't hurt.
+    return evalSet(evalMatrixs, evalNames, eval, new float[evalNames.length]);
+  }
+
+  public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval,
+                        float[] metricsOut) throws XGBoostError {
     String evalInfo = "";
     for (int i = 0; i < evalNames.length; i++) {
       String evalName = evalNames[i];
       DMatrix evalMat = evalMatrixs[i];
       float evalResult = eval.eval(predict(evalMat), evalMat);
       String evalMetric = eval.getMetric();
       evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
+      metricsOut[i] = evalResult;
     }
     return evalInfo;
   }
@@ -64,7 +64,7 @@ public static Booster train(
       Map<String, DMatrix> watches,
       IObjective obj,
       IEvaluation eval) throws XGBoostError {
-    return train(dtrain, params, round, watches, null, obj, eval);
+    return train(dtrain, params, round, watches, null, obj, eval, 0);
   }

   public static Booster train(
@@ -74,7 +74,8 @@
       Map<String, DMatrix> watches,
       float[][] metrics,
       IObjective obj,
-      IEvaluation eval) throws XGBoostError {
+      IEvaluation eval,
+      int earlyStoppingRound) throws XGBoostError {

     //collect eval matrixs
     String[] evalNames;
@@ -89,6 +90,7 @@

     evalNames = names.toArray(new String[names.size()]);
     evalMats = mats.toArray(new DMatrix[mats.size()]);
+    metrics = metrics == null ? new float[evalNames.length][round] : metrics;

     //collect all data matrixs
     DMatrix[] allMats;
@@ -120,19 +122,27 @@

       //evaluation
       if (evalMats.length > 0) {
+        float[] metricsOut = new float[evalMats.length];
         String evalInfo;
         if (eval != null) {
-          evalInfo = booster.evalSet(evalMats, evalNames, eval);
+          evalInfo = booster.evalSet(evalMats, evalNames, eval, metricsOut);
         } else {
-          if (metrics == null) {
-            evalInfo = booster.evalSet(evalMats, evalNames, iter);
-          } else {
-            float[] m = new float[evalMats.length];
-            evalInfo = booster.evalSet(evalMats, evalNames, iter, m);
-            for (int i = 0; i < m.length; i++) {
-              metrics[i][iter] = m[i];
-            }
-          }
+          evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut);
         }
+        for (int i = 0; i < metricsOut.length; i++) {
+          metrics[i][iter] = metricsOut[i];
+        }
+
+        boolean decreasing = true;
+        float[] criterion = metrics[metrics.length - 1];
+        for (int shift = 0; shift < Math.min(iter, earlyStoppingRound) - 1; shift++) {
+          decreasing &= criterion[iter - shift] <= criterion[iter - shift - 1];
+        }
+
+        if (!decreasing) {
+          Rabit.trackerPrint(String.format(
+            "early stopping after %d decreasing rounds", earlyStoppingRound));
+          break;
+        }
         if (Rabit.getRank() == 0) {
           Rabit.trackerPrint(evalInfo + '\n');
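To make the stopping rule above easier to follow: the criterion is the metric history of the last watched dataset (metrics[metrics.length - 1]), and training stops as soon as the most recent min(iter, earlyStoppingRound) - 1 consecutive steps are not all non-increasing. A standalone Scala transcription of the predicate, with my own names but behavior mirroring the Java loop above:

// Returns true when training should stop at iteration `iter`, given the
// per-iteration history of the early-stopping criterion (lower is better).
def shouldStopEarly(criterion: Array[Float], iter: Int, earlyStoppingRound: Int): Boolean = {
  var decreasing = true
  for (shift <- 0 until math.min(iter, earlyStoppingRound) - 1) {
    decreasing &= criterion(iter - shift) <= criterion(iter - shift - 1)
  }
  !decreasing
}

// Example: the error rose over the last steps, so with earlyStoppingRound = 3
// this returns true at iter = 4.
// shouldStopEarly(Array(0.30f, 0.28f, 0.27f, 0.29f, 0.31f), iter = 4, earlyStoppingRound = 3)

With earlyStoppingRound = 0 the shift loop never executes, so decreasing stays true and training always runs the full round count, which matches the numEarlyStoppingRounds -> 0 default.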