diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 4044c0878921..46810bccc8e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -342,9 +342,6 @@ class GBTClassificationModel private[ml](
     }
   }
 
-  /** Number of trees in ensemble */
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): GBTClassificationModel = {
     copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
@@ -353,7 +350,7 @@ class GBTClassificationModel private[ml](
 
   @Since("1.4.0")
   override def toString: String = {
-    s"GBTClassificationModel: uid = $uid, numTrees=$numTrees, numClasses=$numClasses, " +
+    s"GBTClassificationModel: uid = $uid, numTrees=$getNumTrees, numClasses=$numClasses, " +
       s"numFeatures=$numFeatures"
   }
 
@@ -374,7 +371,7 @@ class GBTClassificationModel private[ml](
   /** Raw prediction for the positive class. */
   private def margin(features: Vector): Double = {
     val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+    blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
   }
 
   /** (private[ml]) Convert to a model in the old API */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 29991f59e37c..2c2558f00bb1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -300,12 +300,9 @@ class GBTRegressionModel private[ml](
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+    blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
   }
 
-  /** Number of trees in ensemble */
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): GBTRegressionModel = {
     copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
@@ -314,7 +311,7 @@ class GBTRegressionModel private[ml](
 
   @Since("1.4.0")
   override def toString: String = {
-    s"GBTRegressionModel: uid=$uid, numTrees=$numTrees, numFeatures=$numFeatures"
+    s"GBTRegressionModel: uid=$uid, numTrees=$getNumTrees, numFeatures=$numFeatures"
   }
 
   /**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index abeb4b5331ac..a2208edcb839 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -179,7 +179,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
       assert(raw.size === 2)
       // check that raw prediction is tree predictions dot tree weights
       val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
-      val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
+      val prediction = blas.ddot(gbtModel.getNumTrees, treePredictions, 1,
+        gbtModel.treeWeights, 1)
       assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
 
       // Compare rawPrediction with probability
@@ -436,9 +437,9 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
     gbt.setValidationIndicatorCol(validationIndicatorCol)
     val modelWithValidation = gbt.fit(trainDF.union(validationDF))
 
-    assert(modelWithoutValidation.numTrees === numIter)
+    assert(modelWithoutValidation.getNumTrees === numIter)
     // early stop
-    assert(modelWithValidation.numTrees < numIter)
+    assert(modelWithValidation.getNumTrees < numIter)
 
     val (errorWithoutValidation, errorWithValidation) = {
       val remappedRdd = validationData.map {
@@ -457,10 +458,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
       modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
       OldAlgo.Classification)
     assert(evaluationArray.length === numIter)
-    assert(evaluationArray(modelWithValidation.numTrees) >
-      evaluationArray(modelWithValidation.numTrees - 1))
+    assert(evaluationArray(modelWithValidation.getNumTrees) >
+      evaluationArray(modelWithValidation.getNumTrees - 1))
     var i = 1
-    while (i < modelWithValidation.numTrees) {
+    while (i < modelWithValidation.getNumTrees) {
       assert(evaluationArray(i) <= evaluationArray(i - 1))
       i += 1
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 35c0fc9b02b1..04b0d4b8470f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -274,9 +274,9 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
     gbt.setValidationIndicatorCol(validationIndicatorCol)
     val modelWithValidation = gbt.fit(trainDF.union(validationDF))
 
-    assert(modelWithoutValidation.numTrees === numIter)
+    assert(modelWithoutValidation.getNumTrees === numIter)
     // early stop
-    assert(modelWithValidation.numTrees < numIter)
+    assert(modelWithValidation.getNumTrees < numIter)
 
     val errorWithoutValidation = GradientBoostedTrees.computeWeightedError(
       validationData.map(_.toInstance),
@@ -294,10 +294,10 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
      modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
       OldAlgo.Regression)
     assert(evaluationArray.length === numIter)
-    assert(evaluationArray(modelWithValidation.numTrees) >
-      evaluationArray(modelWithValidation.numTrees - 1))
+    assert(evaluationArray(modelWithValidation.getNumTrees) >
+      evaluationArray(modelWithValidation.getNumTrees - 1))
     var i = 1
-    while (i < modelWithValidation.numTrees) {
+    while (i < modelWithValidation.getNumTrees) {
       assert(evaluationArray(i) <= evaluationArray(i - 1))
       i += 1
     }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f1bbe0b10e22..68e9313805e1 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -344,6 +344,10 @@ object MimaExcludes {
     ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.layers"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"),
 
+    // [SPARK-30630][ML] Remove numTrees in GBT
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.numTrees"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.numTrees"),
+
     // Data Source V2 API changes
     (problem: Problem) => problem match {
       case MissingClassProblem(cls) =>
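
Note (not part of the patch): a minimal sketch of how caller code migrates off the removed `numTrees` member. The object and method names below are hypothetical; only `getNumTrees`, `trees`, `treeWeights`, and `numFeatures` are taken from the accessors this patch actually uses, and `model` is assumed to be an already fitted GBTClassificationModel.

// Illustrative sketch only: after SPARK-30630 the ensemble size is read through
// getNumTrees (or trees.length) instead of the removed public `val numTrees`.
import org.apache.spark.ml.classification.GBTClassificationModel

object GBTMigrationSketch {
  def describeEnsemble(model: GBTClassificationModel): String = {
    val numTrees = model.getNumTrees        // was: model.numTrees
    // The removed val was defined as trees.length, so both report the same size.
    assert(numTrees == model.trees.length)
    s"GBT ensemble: $numTrees trees, ${model.treeWeights.length} tree weights, " +
      s"${model.numFeatures} features"
  }
}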