Skip to content

Commit 77549a9

Browse files
committed
[SPARK-5436] Validate GradientBoostedTrees using runWithValidation
1 parent 3912d33 commit 77549a9

File tree

3 files changed

+111
-8
lines changed

3 files changed

+111
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
6060
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
6161
val algo = boostingStrategy.treeStrategy.algo
6262
algo match {
63-
case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
63+
case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
6464
case Classification =>
6565
// Map labels to -1, +1 so binary classification can be treated as regression.
6666
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
67-
GradientBoostedTrees.boost(remappedInput, boostingStrategy)
67+
GradientBoostedTrees.boost(remappedInput,
68+
remappedInput, boostingStrategy, validate=false)
6869
case _ =>
6970
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
7071
}
@@ -76,8 +77,42 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
7677
def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
7778
run(input.rdd)
7879
}
79-
}
8080

81+
/**
82+
* Method to validate a gradient boosting model
83+
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
84+
* @param input Validation dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
85+
* @return a gradient boosted trees model that can be used for prediction
86+
*/
87+
def runWithValidation(
88+
trainInput: RDD[LabeledPoint],
89+
validateInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
90+
val algo = boostingStrategy.treeStrategy.algo
91+
algo match {
92+
case Regression => GradientBoostedTrees.boost(
93+
trainInput, validateInput, boostingStrategy, validate=true)
94+
case Classification =>
95+
// Map labels to -1, +1 so binary classification can be treated as regression.
96+
val remappedTrainInput = trainInput.map(
97+
x => new LabeledPoint((x.label * 2) - 1, x.features))
98+
val remappedValidateInput = trainInput.map(
99+
x => new LabeledPoint((x.label * 2) - 1, x.features))
100+
GradientBoostedTrees.boost(remappedTrainInput, remappedValidateInput, boostingStrategy,
101+
validate=true)
102+
case _ =>
103+
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
104+
}
105+
}
106+
107+
/**
108+
* Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
109+
*/
110+
def runWithValidation(
111+
trainInput: JavaRDD[LabeledPoint],
112+
validateInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
113+
runWithValidation(trainInput.rdd, validateInput.rdd)
114+
}
115+
}
81116

82117
object GradientBoostedTrees extends Logging {
83118

@@ -108,12 +143,16 @@ object GradientBoostedTrees extends Logging {
108143
/**
109144
* Internal method for performing regression using trees as base learners.
110145
* @param input training dataset
146+
* @param validateInput validation dataset, ignored if validate is set to false.
111147
* @param boostingStrategy boosting parameters
148+
* @param validate whether or not to use the validation dataset.
112149
* @return a gradient boosted trees model that can be used for prediction
113150
*/
114151
private def boost(
115152
input: RDD[LabeledPoint],
116-
boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
153+
validateInput: RDD[LabeledPoint],
154+
boostingStrategy: BoostingStrategy,
155+
validate: Boolean = false): GradientBoostedTreesModel = {
117156

118157
val timer = new TimeTracker()
119158
timer.start("total")
@@ -129,6 +168,7 @@ object GradientBoostedTrees extends Logging {
129168
val learningRate = boostingStrategy.learningRate
130169
// Prepare strategy for individual trees, which use regression with variance impurity.
131170
val treeStrategy = boostingStrategy.treeStrategy.copy
171+
val validationTol = boostingStrategy.validationTol
132172
treeStrategy.algo = Regression
133173
treeStrategy.impurity = Variance
134174
treeStrategy.assertValid()
@@ -151,14 +191,25 @@ object GradientBoostedTrees extends Logging {
151191
baseLearners(0) = firstTreeModel
152192
baseLearnerWeights(0) = 1.0
153193
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
154-
logDebug("error of gbt = " + loss.computeError(startingModel, input))
194+
val errorModel = loss.computeError(startingModel, input)
195+
logDebug("error of gbt = " + errorModel)
196+
155197
// Note: A model of type regression is used since we require raw prediction
156198
timer.stop("building tree 0")
157199

200+
// Just so that it can be accessed below. This error is ignored if validate is set to false.
201+
var prevValidateError = {
202+
if (validate) {
203+
loss.computeError(startingModel, validateInput)
204+
}
205+
else {
206+
errorModel
207+
}
208+
}
209+
158210
// psuedo-residual for second iteration
159211
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
160212
point.features))
161-
162213
var m = 1
163214
while (m < numIterations) {
164215
timer.start(s"building tree $m")
@@ -176,7 +227,24 @@ object GradientBoostedTrees extends Logging {
176227
// Note: A model of type regression is used since we require raw prediction
177228
val partialModel = new GradientBoostedTreesModel(
178229
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
179-
logDebug("error of gbt = " + loss.computeError(partialModel, input))
230+
val errorModel = loss.computeError(partialModel, input)
231+
logDebug("error of gbt = " + errorModel)
232+
233+
if (validate) {
234+
// Stop training early if
235+
// 1. Reduction in error is lesser than the validationTol or
236+
// 2. If the error increases, that is if the model is overfit.
237+
val currentValidateError = loss.computeError(partialModel, validateInput)
238+
if (prevValidateError - currentValidateError < validationTol) {
239+
return new GradientBoostedTreesModel(
240+
boostingStrategy.treeStrategy.algo,
241+
baseLearners.slice(0, m),
242+
baseLearnerWeights.slice(0, m))
243+
}
244+
else {
245+
prevValidateError = currentValidateError
246+
}
247+
}
180248
// Update data with pseudo-residuals
181249
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
182250
point.features))
@@ -191,4 +259,5 @@ object GradientBoostedTrees extends Logging {
191259
new GradientBoostedTreesModel(
192260
boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
193261
}
262+
194263
}

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
3434
* weak hypotheses used in the final model.
3535
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
3636
* learning rate should be between in the interval (0, 1]
37+
* @param validationTol Useful when runWithValidation is used. If the error rate between two
38+
iterations is lesser than the validationTol, then stop. If run
39+
is used, then this parameter is ignored.
40+
41+
a pair of RDD's are supplied to run. If the error rate
42+
* between two iterations is lesser than convergenceTol, then training stops.
3743
*/
3844
@Experimental
3945
case class BoostingStrategy(
@@ -42,7 +48,8 @@ case class BoostingStrategy(
4248
@BeanProperty var loss: Loss,
4349
// Optional boosting parameters
4450
@BeanProperty var numIterations: Int = 100,
45-
@BeanProperty var learningRate: Double = 0.1) extends Serializable {
51+
@BeanProperty var learningRate: Double = 0.1,
52+
@BeanProperty var validationTol: Double = 1e-5) extends Serializable {
4653

4754
/**
4855
* Check validity of parameters.
@@ -62,6 +69,7 @@ case class BoostingStrategy(
6269
}
6370
require(learningRate > 0 && learningRate <= 1,
6471
"Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.")
72+
require(validationTol >= 0, s"validationTol $validationTol should be greater than zero.")
6573
}
6674
}
6775

mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
158158
}
159159
}
160160
}
161+
162+
test("Early stopping when validation data is provided.") {
163+
// Set numIterations large enough so that it early stops.
164+
val numIterations = 20
165+
val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
166+
val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
167+
168+
val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
169+
categoricalFeaturesInfo = Map.empty)
170+
Array(SquaredError, AbsoluteError).foreach { error =>
171+
val boostingStrategy =
172+
new BoostingStrategy(treeStrategy, error, numIterations, validationTol = 0.0)
173+
174+
val gbtValidate = new GradientBoostedTrees(boostingStrategy).runWithValidation(
175+
trainRdd, validateRdd)
176+
assert(gbtValidate.numTrees != numIterations)
177+
178+
val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
179+
val errorWithoutValidation = error.computeError(gbt, validateRdd)
180+
val errorWithValidation = error.computeError(gbtValidate, validateRdd)
181+
assert(errorWithValidation < errorWithoutValidation)
182+
}
183+
184+
}
161185
}
162186

163187
private object GradientBoostedTreesSuite {
@@ -166,4 +190,6 @@ private object GradientBoostedTreesSuite {
166190
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
167191

168192
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
193+
val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120)
194+
val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80)
169195
}

0 commit comments

Comments
 (0)