Skip to content

Commit 2f8e4d0

Browse files
huaxingaodongjoon-hyun
authored andcommitted
[SPARK-30630][ML] Remove numTrees in GBT in 3.0.0
### What changes were proposed in this pull request? Remove ```numTrees``` in GBT in 3.0.0. ### Why are the changes needed? Currently, GBT has ``` /** * Number of trees in ensemble */ Since("2.0.0") val getNumTrees: Int = trees.length ``` and ``` /** Number of trees in ensemble */ val numTrees: Int = trees.length ``` I think we should remove one of them. We deprecated it in 2.4.5 via #27352. ### Does this PR introduce any user-facing change? Yes, remove ```numTrees``` in GBT in 3.0.0 ### How was this patch tested? existing tests Closes #27330 from huaxingao/spark-numTrees. Authored-by: Huaxin Gao <huaxing@us.ibm.com> Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
1 parent f86a1b9 commit 2f8e4d0

File tree

5 files changed

+20
-21
lines changed

5 files changed

+20
-21
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,6 @@ class GBTClassificationModel private[ml](
342342
}
343343
}
344344

345-
/** Number of trees in ensemble */
346-
val numTrees: Int = trees.length
347-
348345
@Since("1.4.0")
349346
override def copy(extra: ParamMap): GBTClassificationModel = {
350347
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
@@ -353,7 +350,7 @@ class GBTClassificationModel private[ml](
353350

354351
@Since("1.4.0")
355352
override def toString: String = {
356-
s"GBTClassificationModel: uid = $uid, numTrees=$numTrees, numClasses=$numClasses, " +
353+
s"GBTClassificationModel: uid = $uid, numTrees=$getNumTrees, numClasses=$numClasses, " +
357354
s"numFeatures=$numFeatures"
358355
}
359356

@@ -374,7 +371,7 @@ class GBTClassificationModel private[ml](
374371
/** Raw prediction for the positive class. */
375372
private def margin(features: Vector): Double = {
376373
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
377-
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
374+
blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
378375
}
379376

380377
/** (private[ml]) Convert to a model in the old API */

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,9 @@ class GBTRegressionModel private[ml](
299299
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
300300
// Classifies by thresholding sum of weighted tree predictions
301301
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
302-
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
302+
blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
303303
}
304304

305-
/** Number of trees in ensemble */
306-
val numTrees: Int = trees.length
307-
308305
@Since("1.4.0")
309306
override def copy(extra: ParamMap): GBTRegressionModel = {
310307
copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
@@ -313,7 +310,7 @@ class GBTRegressionModel private[ml](
313310

314311
@Since("1.4.0")
315312
override def toString: String = {
316-
s"GBTRegressionModel: uid=$uid, numTrees=$numTrees, numFeatures=$numFeatures"
313+
s"GBTRegressionModel: uid=$uid, numTrees=$getNumTrees, numFeatures=$numFeatures"
317314
}
318315

319316
/**

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
179179
assert(raw.size === 2)
180180
// check that raw prediction is tree predictions dot tree weights
181181
val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
182-
val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
182+
val prediction = blas.ddot(gbtModel.getNumTrees, treePredictions, 1,
183+
gbtModel.treeWeights, 1)
183184
assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
184185

185186
// Compare rawPrediction with probability
@@ -436,9 +437,9 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
436437
gbt.setValidationIndicatorCol(validationIndicatorCol)
437438
val modelWithValidation = gbt.fit(trainDF.union(validationDF))
438439

439-
assert(modelWithoutValidation.numTrees === numIter)
440+
assert(modelWithoutValidation.getNumTrees === numIter)
440441
// early stop
441-
assert(modelWithValidation.numTrees < numIter)
442+
assert(modelWithValidation.getNumTrees < numIter)
442443

443444
val (errorWithoutValidation, errorWithValidation) = {
444445
val remappedRdd = validationData.map {
@@ -457,10 +458,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
457458
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
458459
OldAlgo.Classification)
459460
assert(evaluationArray.length === numIter)
460-
assert(evaluationArray(modelWithValidation.numTrees) >
461-
evaluationArray(modelWithValidation.numTrees - 1))
461+
assert(evaluationArray(modelWithValidation.getNumTrees) >
462+
evaluationArray(modelWithValidation.getNumTrees - 1))
462463
var i = 1
463-
while (i < modelWithValidation.numTrees) {
464+
while (i < modelWithValidation.getNumTrees) {
464465
assert(evaluationArray(i) <= evaluationArray(i - 1))
465466
i += 1
466467
}

mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,9 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
274274
gbt.setValidationIndicatorCol(validationIndicatorCol)
275275
val modelWithValidation = gbt.fit(trainDF.union(validationDF))
276276

277-
assert(modelWithoutValidation.numTrees === numIter)
277+
assert(modelWithoutValidation.getNumTrees === numIter)
278278
// early stop
279-
assert(modelWithValidation.numTrees < numIter)
279+
assert(modelWithValidation.getNumTrees < numIter)
280280

281281
val errorWithoutValidation = GradientBoostedTrees.computeWeightedError(
282282
validationData.map(_.toInstance),
@@ -294,10 +294,10 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
294294
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
295295
OldAlgo.Regression)
296296
assert(evaluationArray.length === numIter)
297-
assert(evaluationArray(modelWithValidation.numTrees) >
298-
evaluationArray(modelWithValidation.numTrees - 1))
297+
assert(evaluationArray(modelWithValidation.getNumTrees) >
298+
evaluationArray(modelWithValidation.getNumTrees - 1))
299299
var i = 1
300-
while (i < modelWithValidation.numTrees) {
300+
while (i < modelWithValidation.getNumTrees) {
301301
assert(evaluationArray(i) <= evaluationArray(i - 1))
302302
i += 1
303303
}

project/MimaExcludes.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,10 @@ object MimaExcludes {
344344
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.layers"),
345345
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"),
346346

347+
// [SPARK-30630][ML] Remove numTrees in GBT
348+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.numTrees"),
349+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.numTrees"),
350+
347351
// Data Source V2 API changes
348352
(problem: Problem) => problem match {
349353
case MissingClassProblem(cls) =>

0 commit comments

Comments
 (0)