diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 62cfa39746ff..c5cb03e55019 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -319,7 +319,12 @@ class GBTClassificationModel private[ml]( } } - /** Number of trees in ensemble */ + /** + * Number of trees in ensemble + * + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 3.0.0. + */ + @deprecated("Use getNumTrees instead. This method will be removed in 3.0.0.", "2.4.5") val numTrees: Int = trees.length @Since("1.4.0") @@ -330,7 +335,7 @@ class GBTClassificationModel private[ml]( @Since("1.4.0") override def toString: String = { - s"GBTClassificationModel (uid=$uid) with $numTrees trees" + s"GBTClassificationModel (uid=$uid) with $getNumTrees trees" } /** @@ -349,7 +354,7 @@ class GBTClassificationModel private[ml]( /** Raw prediction for the positive class. */ private def margin(features: Vector): Double = { val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) - blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1) } /** (private[ml]) Convert to a model in the old API */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 07f88d8d5f84..a56b5c4d83bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -255,10 +255,15 @@ class GBTRegressionModel private[ml]( // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) - blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1) } - /** Number of trees in ensemble */ + /** + * Number of trees in ensemble + * + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 3.0.0. + */ + @deprecated("Use getNumTrees instead. This method will be removed in 3.0.0.", "2.4.5") val numTrees: Int = trees.length @Since("1.4.0") @@ -269,7 +274,7 @@ class GBTRegressionModel private[ml]( @Since("1.4.0") override def toString: String = { - s"GBTRegressionModel (uid=$uid) with $numTrees trees" + s"GBTRegressionModel (uid=$uid) with $getNumTrees trees" } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 304977634189..02b6b5d3ef49 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -178,7 +178,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(raw.size === 2) // check that raw prediction is tree predictions dot tree weights val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) - val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) + val prediction = blas.ddot(gbtModel.getNumTrees, treePredictions, 1, + gbtModel.treeWeights, 1) assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) // Compare rawPrediction with probability @@ -410,9 +411,9 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { gbt.setValidationIndicatorCol(validationIndicatorCol) val modelWithValidation = gbt.fit(trainDF.union(validationDF)) - assert(modelWithoutValidation.numTrees === numIter) + assert(modelWithoutValidation.getNumTrees === numIter) // early stop - assert(modelWithValidation.numTrees < numIter) + assert(modelWithValidation.getNumTrees < numIter) val (errorWithoutValidation, errorWithValidation) = { val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) @@ -428,10 +429,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, OldAlgo.Classification) assert(evaluationArray.length === numIter) - assert(evaluationArray(modelWithValidation.numTrees) > - evaluationArray(modelWithValidation.numTrees - 1)) + assert(evaluationArray(modelWithValidation.getNumTrees) > + evaluationArray(modelWithValidation.getNumTrees - 1)) var i = 1 - while (i < modelWithValidation.numTrees) { + while (i < modelWithValidation.getNumTrees) { assert(evaluationArray(i) <= evaluationArray(i - 1)) i += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index b145c7a3dc95..9342bc031fd2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -249,9 +249,9 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { gbt.setValidationIndicatorCol(validationIndicatorCol) val modelWithValidation = gbt.fit(trainDF.union(validationDF)) - assert(modelWithoutValidation.numTrees === numIter) + assert(modelWithoutValidation.getNumTrees === numIter) // early stop - assert(modelWithValidation.numTrees < numIter) + assert(modelWithValidation.getNumTrees < numIter) val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, modelWithoutValidation.trees, modelWithoutValidation.treeWeights, @@ -267,10 +267,10 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, OldAlgo.Regression) assert(evaluationArray.length === numIter) - assert(evaluationArray(modelWithValidation.numTrees) > - evaluationArray(modelWithValidation.numTrees - 1)) + assert(evaluationArray(modelWithValidation.getNumTrees) > + evaluationArray(modelWithValidation.getNumTrees - 1)) var i = 1 - while (i < modelWithValidation.numTrees) { + while (i < modelWithValidation.getNumTrees) { assert(evaluationArray(i) <= evaluationArray(i - 1)) i += 1 }