@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
6060 def run (input : RDD [LabeledPoint ]): GradientBoostedTreesModel = {
6161 val algo = boostingStrategy.treeStrategy.algo
6262 algo match {
63- case Regression => GradientBoostedTrees .boost(input, boostingStrategy)
63+ case Regression => GradientBoostedTrees .boost(input, input, boostingStrategy, validate = false )
6464 case Classification =>
6565 // Map labels to -1, +1 so binary classification can be treated as regression.
6666 val remappedInput = input.map(x => new LabeledPoint ((x.label * 2 ) - 1 , x.features))
67- GradientBoostedTrees .boost(remappedInput, boostingStrategy)
67+ GradientBoostedTrees .boost(remappedInput,
68+ remappedInput, boostingStrategy, validate= false )
6869 case _ =>
6970 throw new IllegalArgumentException (s " $algo is not supported by the gradient boosting. " )
7071 }
@@ -76,8 +77,42 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
7677 def run (input : JavaRDD [LabeledPoint ]): GradientBoostedTreesModel = {
7778 run(input.rdd) // unwrap the JavaRDD and delegate to the RDD-based overload of run
7879 }
79- }
8080
81+ /**
82+ * Method to train a gradient boosting model, stopping early based on validation error.
83+ * @param trainInput Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
84+ * @param validateInput Validation dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint ]].
85+ * @return a gradient boosted trees model that can be used for prediction
86+ */
87+ def runWithValidation (
88+ trainInput : RDD [LabeledPoint ],
89+ validateInput : RDD [LabeledPoint ]): GradientBoostedTreesModel = {
90+ val algo = boostingStrategy.treeStrategy.algo
91+ algo match {
92+ case Regression => GradientBoostedTrees .boost(
93+ trainInput, validateInput, boostingStrategy, validate= true )
94+ case Classification =>
95+ // Map labels of both datasets to -1, +1 so binary classification can be treated as regression.
96+ val remappedTrainInput = trainInput.map(
97+ x => new LabeledPoint ((x.label * 2 ) - 1 , x.features))
98+ val remappedValidateInput = validateInput.map(
99+ x => new LabeledPoint ((x.label * 2 ) - 1 , x.features))
100+ GradientBoostedTrees .boost(remappedTrainInput, remappedValidateInput, boostingStrategy,
101+ validate= true )
102+ case _ =>
103+ throw new IllegalArgumentException (s " $algo is not supported by the gradient boosting. " )
104+ }
105+ }
106+
107+ /**
108+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation ]].
109+ */
110+ def runWithValidation (
111+ trainInput : JavaRDD [LabeledPoint ],
112+ validateInput : JavaRDD [LabeledPoint ]): GradientBoostedTreesModel = {
113+ runWithValidation(trainInput.rdd, validateInput.rdd)
114+ }
115+ }
81116
82117object GradientBoostedTrees extends Logging {
83118
@@ -108,12 +143,16 @@ object GradientBoostedTrees extends Logging {
108143 /**
109144 * Internal method for performing regression using trees as base learners.
110145 * @param input training dataset
146+ * @param validateInput validation dataset, ignored if validate is set to false.
111147 * @param boostingStrategy boosting parameters
148+ * @param validate whether or not to use the validation dataset.
112149 * @return a gradient boosted trees model that can be used for prediction
113150 */
114151 private def boost (
115152 input : RDD [LabeledPoint ],
116- boostingStrategy : BoostingStrategy ): GradientBoostedTreesModel = {
153+ validateInput : RDD [LabeledPoint ],
154+ boostingStrategy : BoostingStrategy ,
155+ validate : Boolean = false ): GradientBoostedTreesModel = {
117156
118157 val timer = new TimeTracker ()
119158 timer.start(" total" )
@@ -129,6 +168,7 @@ object GradientBoostedTrees extends Logging {
129168 val learningRate = boostingStrategy.learningRate
130169 // Prepare strategy for individual trees, which use regression with variance impurity.
131170 val treeStrategy = boostingStrategy.treeStrategy.copy
171+ val validationTol = boostingStrategy.validationTol
132172 treeStrategy.algo = Regression
133173 treeStrategy.impurity = Variance
134174 treeStrategy.assertValid()
@@ -151,14 +191,25 @@ object GradientBoostedTrees extends Logging {
151191 baseLearners(0 ) = firstTreeModel
152192 baseLearnerWeights(0 ) = 1.0
153193 val startingModel = new GradientBoostedTreesModel (Regression , Array (firstTreeModel), Array (1.0 ))
154- logDebug(" error of gbt = " + loss.computeError(startingModel, input))
194+ val errorModel = loss.computeError(startingModel, input)
195+ logDebug(" error of gbt = " + errorModel)
196+
155197 // Note: A model of type regression is used since we require raw prediction
156198 timer.stop(" building tree 0" )
157199
200+ // Just so that it can be accessed below. This error is ignored if validate is set to false.
201+ var prevValidateError = {
202+ if (validate) {
203+ loss.computeError(startingModel, validateInput)
204+ }
205+ else {
206+ errorModel
207+ }
208+ }
209+
158210 // pseudo-residual for second iteration
159211 data = input.map(point => LabeledPoint (loss.gradient(startingModel, point),
160212 point.features))
161-
162213 var m = 1
163214 while (m < numIterations) {
164215 timer.start(s " building tree $m" )
@@ -176,7 +227,24 @@ object GradientBoostedTrees extends Logging {
176227 // Note: A model of type regression is used since we require raw prediction
177228 val partialModel = new GradientBoostedTreesModel (
178229 Regression , baseLearners.slice(0 , m + 1 ), baseLearnerWeights.slice(0 , m + 1 ))
179- logDebug(" error of gbt = " + loss.computeError(partialModel, input))
230+ val errorModel = loss.computeError(partialModel, input)
231+ logDebug(" error of gbt = " + errorModel)
232+
233+ if (validate) {
234+ // Stop training early if
235+ // 1. Reduction in error is less than the validationTol or
236+ // 2. If the error increases, that is if the model is overfit.
237+ val currentValidateError = loss.computeError(partialModel, validateInput)
238+ if (prevValidateError - currentValidateError < validationTol) {
239+ return new GradientBoostedTreesModel (
240+ boostingStrategy.treeStrategy.algo,
241+ baseLearners.slice(0 , m),
242+ baseLearnerWeights.slice(0 , m))
243+ }
244+ else {
245+ prevValidateError = currentValidateError
246+ }
247+ }
180248 // Update data with pseudo-residuals
181249 data = input.map(point => LabeledPoint (- loss.gradient(partialModel, point),
182250 point.features))
@@ -191,4 +259,5 @@ object GradientBoostedTrees extends Logging {
191259 new GradientBoostedTreesModel (
192260 boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
193261 }
262+
194263}
0 commit comments