diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c9bbd37a6736..ade0960f87a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -23,9 +23,8 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
@@ -33,6 +32,7 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.LogLoss
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._
 @Since("1.4.0")
 class GBTClassifier @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
-  extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+  extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel]
   with GBTClassifierParams with DefaultParamsWritable with Logging {
 
   @Since("1.4.0")
@@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") (
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
 
+    val numClasses = 2
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+ + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) instr.logNumFeatures(numFeatures) - instr.logNumClasses(2) + instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) @@ -202,8 +209,9 @@ class GBTClassificationModel private[ml]( @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - @Since("1.6.0") override val numFeatures: Int) - extends PredictionModel[Vector, GBTClassificationModel] + @Since("1.6.0") override val numFeatures: Int, + @Since("2.2.0") override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, GBTClassificationModel] with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { @@ -211,6 +219,20 @@ class GBTClassificationModel private[ml]( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTClassificationModel + * + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + * @param numFeatures The number of features. + */ + private[ml] def this( + uid: String, + _trees: Array[DecisionTreeRegressionModel], + _treeWeights: Array[Double], + numFeatures: Int) = + this(uid, _trees, _treeWeights, numFeatures, 2) + /** * Construct a GBTClassificationModel * @@ -219,7 +241,7 @@ class GBTClassificationModel private[ml]( */ @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = - this(uid, _trees, _treeWeights, -1) + this(uid, _trees, _treeWeights, -1, 2) @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees @@ -242,11 +264,29 @@ class GBTClassificationModel private[ml]( } override protected def predict(features: Vector): Double = { - // TODO: When we add a generic Boosting class, handle transform there? 
-    // Classifies by thresholding sum of weighted tree predictions
-    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
-    if (prediction > 0.0) 1.0 else 0.0
+    // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
+    if (isDefined(thresholds)) {
+      super.predict(features)
+    } else {
+      if (margin(features) > 0.0) 1.0 else 0.0
+    }
+  }
+
+  override protected def predictRaw(features: Vector): Vector = {
+    val prediction: Double = margin(features)
+    Vectors.dense(Array(-prediction, prediction))
+  }
+
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        dv.values(0) = loss.computeProbability(dv.values(0))
+        dv.values(1) = 1.0 - dv.values(0)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in GBTClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
   }
 
   /** Number of trees in ensemble */
@@ -254,7 +294,7 @@ class GBTClassificationModel private[ml](
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): GBTClassificationModel = {
-    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
       extra).setParent(parent)
   }
 
@@ -276,11 +316,20 @@ class GBTClassificationModel private[ml](
   @Since("2.0.0")
   lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
 
+  /** Raw prediction for the positive class. */
+  private def margin(features: Vector): Double = {
+    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+  }
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
   }
 
+  // hard coded loss, which is not meant to be changed in the model
+  private val loss = getOldLossType
+
   @Since("2.0.0")
   override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
 }
 
@@ -288,6 +337,9 @@ class GBTClassificationModel private[ml](
 @Since("2.0.0")
 object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
 
+  private val numFeaturesKey: String = "numFeatures"
+  private val numTreesKey: String = "numTrees"
+
   @Since("2.0.0")
   override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
 
@@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
 
     override protected def saveImpl(path: String): Unit = {
       val extraMetadata: JObject = Map(
-        "numFeatures" -> instance.numFeatures,
-        "numTrees" -> instance.getNumTrees)
+        numFeaturesKey -> instance.numFeatures,
+        numTreesKey -> instance.getNumTrees)
       EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
     }
   }
@@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
-      val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-      val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+      val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
+      val numTrees = (metadata.metadata \ numTreesKey).extract[Int]
 
       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
@@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       }
       require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
         s" trees based on metadata but found ${trees.length} trees.")
-      val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+      val model = new GBTClassificationModel(metadata.uid,
+        trees, treeWeights, numFeatures)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
@@ -339,7 +392,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       oldModel: OldGBTModel,
       parent: GBTClassifier,
       categoricalFeatures: Map[Int, Int],
-      numFeatures: Int = -1): GBTClassificationModel = {
+      numFeatures: Int = -1,
+      numClasses: Int = 2): GBTClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -347,6 +401,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
-    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
+    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
   }
 }
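For orientation before the parameter and loss changes below: `predictRaw` maps the ensemble margin m to the two-class raw vector (-m, m), `raw2probabilityInPlace` pushes each raw score through the loss's margin-to-probability function, and the label is the argmax over probabilities (or the thresholded variant when `thresholds` is set). A simplified, self-contained sketch of that flow using plain arrays — illustrative names only, not the actual `ProbabilisticClassificationModel` internals:

```scala
// Simplified sketch of the prediction pipeline the overrides above plug into.
// Not Spark code: ProbabilisticClassificationModel also handles thresholds
// and output-column selection.
object PredictionPathSketch {
  // Margin-to-probability mapping for mllib's LogLoss (labels in {-1, +1}).
  def computeProbability(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))

  // predictRaw: the weighted sum of tree predictions m becomes (-m, m).
  def predictRaw(margin: Double): Array[Double] = Array(-margin, margin)

  // raw2probability: probability of class 1, and its complement for class 0.
  def raw2probability(raw: Array[Double]): Array[Double] = {
    val p1 = computeProbability(raw(1))
    Array(1.0 - p1, p1)
  }

  // probability2prediction: argmax over class probabilities.
  def probability2prediction(prob: Array[Double]): Double =
    if (prob(1) > prob(0)) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    val margin = 0.3
    val prob = raw2probability(predictRaw(margin))
    println(prob.mkString("[", ", ", "]"))
    println(probability2prediction(prob)) // 1.0, consistent with margin > 0
  }
}
```

Keeping `predict` on the margin-sign fast path when no thresholds are set avoids allocating raw and probability vectors for every row.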
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index c7a8f76eca84..5eb707dfe7bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
@@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam
   def getLossType: String = $(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
-  override private[ml] def getOldLossType: OldLoss = {
+  override private[ml] def getOldLossType: OldClassificationLoss = {
     getLossType match {
       case "logistic" => OldLogLoss
       case _ =>
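The behavioral point in this hunk is the narrowed result type: Scala overrides are covariant in their return type, so `GBTClassifierParams` can declare `getOldLossType` as `OldClassificationLoss` while still overriding the `OldLoss`-returning base method, and the model can call `computeProbability` without a cast. A minimal standalone illustration (demo names, not the Spark traits):

```scala
// Minimal demo of covariant return-type overriding; names are illustrative.
object ReturnTypeNarrowingDemo {
  trait Loss { def computeError(prediction: Double, label: Double): Double }
  trait ClassificationLoss extends Loss {
    def computeProbability(margin: Double): Double
  }

  object DemoLogLoss extends ClassificationLoss {
    def computeError(prediction: Double, label: Double): Double =
      2.0 * math.log1p(math.exp(-2.0 * label * prediction))
    def computeProbability(margin: Double): Double =
      1.0 / (1.0 + math.exp(-2.0 * margin))
  }

  trait Params { def oldLossType: Loss }
  trait ClassifierParams extends Params {
    // Narrowed result type: callers statically get a ClassificationLoss.
    override def oldLossType: ClassificationLoss = DemoLogLoss
  }

  def main(args: Array[String]): Unit = {
    val params = new ClassifierParams {}
    println(params.oldLossType.computeProbability(0.5)) // no cast needed
  }
}
```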
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 5d92ce495b04..9339f0a23c1b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.mllib.util.MLUtils
 
-
 /**
  * :: DeveloperApi ::
  * Class for log loss calculation (for classification).
@@ -32,7 +31,7 @@ import org.apache.spark.mllib.util.MLUtils
  */
 @Since("1.2.0")
 @DeveloperApi
-object LogLoss extends Loss {
+object LogLoss extends ClassificationLoss {
 
   /**
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
@@ -52,4 +51,11 @@ object LogLoss extends Loss {
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
   }
+
+  /**
+   * Returns the estimated probability of a label of 1.0.
+   */
+  override private[spark] def computeProbability(margin: Double): Double = {
+    1.0 / (1.0 + math.exp(-2.0 * margin))
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 09274a2e1b2a..e7ffb3f8f53c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD
 
-
 /**
  * :: DeveloperApi ::
  * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -67,3 +66,10 @@ trait Loss extends Serializable {
    */
   private[spark] def computeError(prediction: Double, label: Double): Double
 }
+
+private[spark] trait ClassificationLoss extends Loss {
+  /**
+   * Computes the class probability given the margin.
+   */
+  private[spark] def computeProbability(margin: Double): Double
+}
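The factor of 2 in `computeProbability` comes from the label encoding: mllib's GBT trains on labels in {-1, +1} and `LogLoss.computeError` is 2·log(1 + exp(-2·label·prediction)), so for a positive label the loss equals -2·log(p) and the probability is the logistic of twice the margin. A small numeric cross-check of that identity (standalone sketch; `math.log1p(math.exp(...))` stands in for the overflow-safe `MLUtils.log1pExp`):

```scala
// Numeric cross-check: for label +1.0, loss == -2 * log(p), hence
// p == exp(-loss / 2); and P(y=1 | m) + P(y=1 | -m) == 1, which is what
// raw2probabilityInPlace relies on. Standalone sketch, not Spark code.
object LogLossProbabilityCheck {
  def computeError(prediction: Double, label: Double): Double = {
    val margin = 2.0 * label * prediction
    2.0 * math.log1p(math.exp(-margin)) // MLUtils.log1pExp, minus overflow guard
  }

  def computeProbability(margin: Double): Double =
    1.0 / (1.0 + math.exp(-2.0 * margin))

  def main(args: Array[String]): Unit = {
    for (margin <- Seq(-2.0, -0.5, 0.0, 0.5, 2.0)) {
      val p = computeProbability(margin)
      assert(math.abs(p - math.exp(-computeError(margin, 1.0) / 2.0)) < 1e-12)
      assert(math.abs(p + computeProbability(-margin) - 1.0) < 1e-12)
    }
    println("identities hold")
  }
}
```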
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 7c36745ab213..0598943c3d4b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -17,20 +17,24 @@
 
 package org.apache.spark.ml.classification
 
+import com.github.fommil.netlib.BLAS
+
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+import org.apache.spark.mllib.tree.loss.LogLoss
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.util.Utils
 
 /**
@@ -49,6 +53,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
   private var data: RDD[LabeledPoint] = _
   private var trainData: RDD[LabeledPoint] = _
   private var validationData: RDD[LabeledPoint] = _
+  private val eps: Double = 1e-5
+  private val absEps: Double = 1e-8
 
   override def beforeAll() {
     super.beforeAll()
@@ -66,10 +72,156 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
       Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
-      Array(1.0), 1)
+      Array(1.0), 1, 2)
     ParamsSuite.checkParams(model)
   }
 
+  test("GBTClassifier: default params") {
+    val gbt = new GBTClassifier
+    assert(gbt.getLabelCol === "label")
+    assert(gbt.getFeaturesCol === "features")
+    assert(gbt.getPredictionCol === "prediction")
+    assert(gbt.getRawPredictionCol === "rawPrediction")
+    assert(gbt.getProbabilityCol === "probability")
+    val df = trainData.toDF()
+    val model = gbt.fit(df)
+    model.transform(df)
+      .select("label", "probability", "prediction", "rawPrediction")
+      .collect()
+    intercept[NoSuchElementException] {
+      model.getThresholds
+    }
+    assert(model.getFeaturesCol === "features")
+    assert(model.getPredictionCol === "prediction")
+    assert(model.getRawPredictionCol === "rawPrediction")
+    assert(model.getProbabilityCol === "probability")
+    assert(model.hasParent)
+
+    // copied model must have the same parent.
+    MLTestingUtils.checkCopy(model)
+  }
+
+  test("setThreshold, getThreshold") {
+    val gbt = new GBTClassifier
+
+    // default
+    withClue("GBTClassifier should not have thresholds set by default.") {
+      intercept[NoSuchElementException] {
+        gbt.getThresholds
+      }
+    }
+
+    // Set via thresholds
+    val gbt2 = new GBTClassifier
+    val threshold = Array(0.3, 0.7)
+    gbt2.setThresholds(threshold)
+    assert(gbt2.getThresholds === threshold)
+  }
+
+  test("thresholds prediction") {
+    val gbt = new GBTClassifier
+    val df = trainData.toDF()
+    val binaryModel = gbt.fit(df)
+
+    // should predict all zeros
+    binaryModel.setThresholds(Array(0.0, 1.0))
+    val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect()
+    assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0))
+
+    // should predict all ones
+    binaryModel.setThresholds(Array(1.0, 0.0))
+    val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect()
+    assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0))
+
+    val gbtBase = new GBTClassifier
+    val model = gbtBase.fit(df)
+    val basePredictions = model.transform(df).select("prediction").collect()
+
+    // constant threshold scaling is the same as no thresholds
+    binaryModel.setThresholds(Array(1.0, 1.0))
+    val scaledPredictions = binaryModel.transform(df).select("prediction").collect()
+    assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
+      scaled.getDouble(0) === base.getDouble(0)
+    })
+
+    // force it to use the predict method
+    model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1))
+    val predictionsWithPredict = model.transform(df).select("prediction").collect()
+    assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0))
+  }
+
+  test("GBTClassifier: Predictor, Classifier methods") {
+    val rawPredictionCol = "rawPrediction"
+    val predictionCol = "prediction"
+    val labelCol = "label"
+    val featuresCol = "features"
+    val probabilityCol = "probability"
+
+    val gbt = new GBTClassifier().setSeed(123)
+    val trainingDataset = trainData.toDF(labelCol, featuresCol)
+    val gbtModel = gbt.fit(trainingDataset)
+    assert(gbtModel.numClasses === 2)
+    val numFeatures = trainingDataset.select(featuresCol).first().getAs[Vector](0).size
+    assert(gbtModel.numFeatures === numFeatures)
+
+    val blas = BLAS.getInstance()
+
+    val validationDataset = validationData.toDF(labelCol, featuresCol)
+    val results = gbtModel.transform(validationDataset)
+    // check that raw prediction is tree predictions dot tree weights
+    results.select(rawPredictionCol, featuresCol).collect().foreach {
+      case Row(raw: Vector, features: Vector) =>
+        assert(raw.size === 2)
+        val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
+        val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
+        assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
+    }
+
+    // Compare rawPrediction with probability
+    results.select(rawPredictionCol, probabilityCol).collect().foreach {
+      case Row(raw: Vector, prob: Vector) =>
+        assert(raw.size === 2)
+        assert(prob.size === 2)
+        // Note: we should check other loss types for classification if they are added
+        val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value))
+        assert(prob(0) ~== predFromRaw(0) relTol eps)
+        assert(prob(1) ~== predFromRaw(1) relTol eps)
+        assert(prob(0) + prob(1) ~== 1.0 absTol absEps)
+    }
+
+    // Compare prediction with probability
+    results.select(predictionCol, probabilityCol).collect().foreach {
+      case Row(pred: Double, prob: Vector) =>
+        val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+        assert(pred == predFromProb)
+    }
+
+    // force it to use raw2prediction
+    gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("")
+    val resultsUsingRaw2Predict =
+      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+    resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use probability2prediction
+    gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol)
+    val resultsUsingProb2Predict =
+      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+    resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+
+    // force it to use predict
+    gbtModel.setRawPredictionCol("").setProbabilityCol("")
+    val resultsUsingPredict =
+      gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect()
+    resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
+      case (pred1, pred2) => assert(pred1 === pred2)
+    }
+  }
+
   test("GBT parameter stepSize should be in interval (0, 1]") {
     withClue("GBT parameter stepSize should be in interval (0, 1]") {
       intercept[IllegalArgumentException] {
@@ -246,7 +398,8 @@ private object GBTClassifierSuite extends SparkFunSuite {
     val newModel = gbt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = GBTClassificationModel.fromOld(
-      oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures)
+      oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures,
+      numFeatures, numClasses = 2)
     TreeTests.checkEqual(oldModelAsNew, newModel)
     assert(newModel.numFeatures === numFeatures)
     assert(oldModelAsNew.numFeatures === numFeatures)
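End-to-end, the user-visible effect of the patch is that a fitted `GBTClassificationModel` now emits `rawPrediction` and `probability` columns like other probabilistic classifiers. A minimal usage sketch, assuming a local `SparkSession` and a hand-built training DataFrame (all data below is made up for illustration):

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Minimal usage sketch of the new probabilistic output columns.
object GbtProbabilityExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("gbt-prob").getOrCreate()
    import spark.implicits._

    val train = Seq(
      (0.0, Vectors.dense(0.0, 1.0)),
      (1.0, Vectors.dense(1.0, 0.0)),
      (0.0, Vectors.dense(0.1, 0.9)),
      (1.0, Vectors.dense(0.9, 0.2))
    ).toDF("label", "features")

    val model = new GBTClassifier().setMaxIter(5).setSeed(42L).fit(train)

    // rawPrediction is (-margin, margin); probability applies the log-loss
    // margin-to-probability mapping; prediction is the (thresholded) argmax.
    model.transform(train)
      .select("label", "rawPrediction", "probability", "prediction")
      .show(truncate = false)

    spark.stop()
  }
}
```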