From 08831e79ae22f53ebda20ae1fcc88ce4188681c9 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 30 Dec 2016 15:15:12 -0500 Subject: [PATCH 01/20] [SPARK-14975][ML][WIP] Fixed GBTClassifier to predict probability per training instance and fixed interfaces --- .../DecisionTreeClassifier.scala | 2 + .../ml/classification/GBTClassifier.scala | 72 ++++++++++++++----- .../classification/GBTClassifierSuite.scala | 5 +- 3 files changed, 60 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 9f60f0896ec5..39a959781df2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -177,6 +177,8 @@ class DecisionTreeClassificationModel private[ml] ( /** * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. + * @param numFeatures The number of features. + * @param numClasses The number of classes to predict. */ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c9bbd37a6736..7ffbf7b6510d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,12 +20,10 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ @@ -58,7 +56,7 @@ import org.apache.spark.sql.functions._ @Since("1.4.0") class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) - extends Predictor[Vector, GBTClassifier, GBTClassificationModel] + extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel] with GBTClassifierParams with DefaultParamsWritable with Logging { @Since("1.4.0") @@ -158,6 +156,13 @@ class GBTClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) + val numClasses: Int = getNumClasses(dataset) + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + val instr = Instrumentation.create(this, oldDataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, @@ -167,7 +172,7 @@ class GBTClassifier @Since("1.4.0") ( val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures, numClasses) instr.logSuccess(m) m } @@ -202,8 +207,9 @@ class GBTClassificationModel private[ml]( @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - @Since("1.6.0") override val numFeatures: Int) - extends PredictionModel[Vector, GBTClassificationModel] + @Since("1.6.0") override val numFeatures: Int, + @Since("2.2.0") override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, GBTClassificationModel] with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { @@ -218,8 +224,9 @@ class GBTClassificationModel private[ml]( * @param _treeWeights Weights for the decision trees in the ensemble. */ @Since("1.6.0") - def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = - this(uid, _trees, _treeWeights, -1) + def this(uid: String, _trees: Array[DecisionTreeRegressionModel], + _treeWeights: Array[Double]) = + this(uid, _trees, _treeWeights, -1, 2) @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees @@ -249,12 +256,35 @@ class GBTClassificationModel private[ml]( if (prediction > 0.0) 1.0 else 0.0 } + override protected def predictRaw(features: Vector): Vector = { + val treeProbabilities = _trees + .map(_.rootNode.predictImpl(features).impurityStats.stats.clone()) + val weightedVectors = treeProbabilities.zipWithIndex + .map(zipped => zipped._1.map(value => value * _treeWeights.apply(zipped._2))) + // Return the averaged weighted vector + Vectors.dense(weightedVectors.reduce((a, b) => (a, b) + .zipped + .map(_ + _)) + .map(value => value / numTrees)) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in GBTClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + /** Number of trees in ensemble */ val numTrees: Int = trees.length @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses), extra).setParent(parent) } @@ -288,6 +318,10 @@ class GBTClassificationModel private[ml]( @Since("2.0.0") object GBTClassificationModel extends MLReadable[GBTClassificationModel] { + private val numFeaturesKey: String = "numFeatures" + private val numTreesKey: String = "numTrees" + private val numClassesKey: String = "numClasses" + @Since("2.0.0") override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader @@ -300,8 +334,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { override protected def saveImpl(path: String): Unit = { val extraMetadata: JObject = Map( - "numFeatures" -> instance.numFeatures, - "numTrees" -> instance.getNumTrees) + numFeaturesKey -> instance.numFeatures, + numTreesKey -> instance.getNumTrees, + numClassesKey -> instance.numClasses) EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -316,8 +351,9 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { implicit val format = DefaultFormats val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) - val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] - val numTrees = (metadata.metadata \ "numTrees").extract[Int] + val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] + val numTrees = (metadata.metadata \ numTreesKey).extract[Int] + val numClasses = (metadata.metadata \ numClassesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => @@ -328,7 +364,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { } require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") - val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) + val model = new GBTClassificationModel(metadata.uid, + trees, treeWeights, numFeatures, numClasses) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -339,7 +376,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], - numFeatures: Int = -1): GBTClassificationModel = { + numFeatures: Int = -1, + numClasses: Int = 2): GBTClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -347,6 +385,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 7c36745ab213..9fe873a7a3ec 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -66,7 +66,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)), - Array(1.0), 1) + Array(1.0), 1, 2) ParamsSuite.checkParams(model) } @@ -246,7 +246,8 @@ private object GBTClassifierSuite extends SparkFunSuite { val newModel = gbt.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = GBTClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures) + oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, + numFeatures, numClasses = 2) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.numFeatures === numFeatures) assert(oldModelAsNew.numFeatures === numFeatures) From e73b60f93badf2c3ca023cdec3bb520a72e6a2fd Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 30 Dec 2016 15:20:43 -0500 Subject: [PATCH 02/20] Fixed scala style empty line --- .../scala/org/apache/spark/ml/classification/GBTClassifier.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 7ffbf7b6510d..f916591f701f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint From d29b70df4e131b12bd79acff4a38138ddcdcc10c Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 30 Dec 2016 16:47:11 -0500 Subject: [PATCH 03/20] Fixed binary compatibility tests --- .../spark/ml/classification/GBTClassifier.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f916591f701f..f8b3e9e5cfcd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -200,6 +200,8 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. + * @param numFeatures The number of features. + * @param numClasses The number of classes. * * @note Multiclass labels are not currently supported. */ @@ -218,6 +220,18 @@ class GBTClassificationModel private[ml]( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTClassificationModel + * + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + * @param numFeatures The number of features. + */ + @Since("1.6.0") + def this(uid: String, _trees: Array[DecisionTreeRegressionModel], + _treeWeights: Array[Double], numFeatures: Int) = + this(uid, _trees, _treeWeights, numFeatures, 2) + /** * Construct a GBTClassificationModel * @@ -227,7 +241,7 @@ class GBTClassificationModel private[ml]( @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = - this(uid, _trees, _treeWeights, -1, 2) + this(uid, _trees, _treeWeights, -1, 2) @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees From d4afdd0c17deb15298eb6a893526cc37a9abf20e Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 3 Jan 2017 15:50:26 -0500 Subject: [PATCH 04/20] Fixing GBT classifier based on comments --- .../ml/classification/GBTClassifier.scala | 25 +++++++++++-------- .../classification/GBTClassifierSuite.scala | 17 +++++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index f8b3e9e5cfcd..7a60f5736506 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,7 +20,6 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint @@ -33,6 +32,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} +import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -272,21 +272,24 @@ class GBTClassificationModel private[ml]( } override protected def predictRaw(features: Vector): Vector = { - val treeProbabilities = _trees - .map(_.rootNode.predictImpl(features).impurityStats.stats.clone()) - val weightedVectors = treeProbabilities.zipWithIndex - .map(zipped => zipped._1.map(value => value * _treeWeights.apply(zipped._2))) - // Return the averaged weighted vector - Vectors.dense(weightedVectors.reduce((a, b) => (a, b) - .zipped - .map(_ + _)) - .map(value => value / numTrees)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) + val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + Vectors.dense(Array(-prediction, prediction)) } override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { rawPrediction match { + // The probability can be calculated for positive result: + // p+(x) = 1 / (1 + e^(-2*F(x))) + // and negative result: + // p-(x) = 1 / (1 + e^(2*F(x))) case dv: DenseVector => - ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) + var i = 0 + val size = dv.size + while (i < size) { + dv.values(i) = 1 / MLUtils.log1pExp(-dv.values(i)) + i += 1 + } dv case sv: SparseVector => throw new RuntimeException("Unexpected error in GBTClassificationModel:" + diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 9fe873a7a3ec..73bd0945c51c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -70,6 +70,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(model) } + test("Verify predicted probabilities correspond to labels") { + val rawPredictionCol = "MyRawPrediction" + val predictionCol = "MyPrediction" + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + .setSeed(123) + .setRawPredictionCol(rawPredictionCol) + .setPredictionCol(predictionCol) + val gbtModel = gbt.fit(trainData.toDF()) + val scoredData = gbtModel.transform(validationData.toDF()) + scoredData.select(rawPredictionCol, predictionCol).foreach(row => print(row(0))) + } + test("GBT parameter stepSize should be in interval (0, 1]") { withClue("GBT parameter stepSize should be in interval (0, 1]") { intercept[IllegalArgumentException] { From 62702c8df47a045dfc408500f599426784d4a850 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 5 Jan 2017 14:15:57 -0500 Subject: [PATCH 05/20] Fixing probabilities calculated from raw scores --- .../ml/classification/GBTClassifier.scala | 6 ++--- .../classification/GBTClassifierSuite.scala | 24 ++++++++++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 7a60f5736506..40d148b3c166 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -280,14 +280,14 @@ class GBTClassificationModel private[ml]( override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { rawPrediction match { // The probability can be calculated for positive result: - // p+(x) = 1 / (1 + e^(-2*F(x))) + // p+(x) = 1 / (1 + e^(-F(x))) // and negative result: - // p-(x) = 1 / (1 + e^(2*F(x))) + // p-(x) = 1 / (1 + e^(F(x))) case dv: DenseVector => var i = 0 val size = dv.size while (i < size) { - dv.values(i) = 1 / MLUtils.log1pExp(-dv.values(i)) + dv.values(i) = 1 / (1 + math.exp(-dv.values(i))) i += 1 } dv diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 73bd0945c51c..08e2a4d94d34 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode @@ -28,7 +28,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.{MLUtils, MLlibTestSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils @@ -70,9 +70,11 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(model) } - test("Verify predicted probabilities correspond to labels") { + test("Verify raw scores correspond to labels") { val rawPredictionCol = "MyRawPrediction" val predictionCol = "MyPrediction" + val labelCol = "label" + val featuresCol = "features" val gbt = new GBTClassifier() .setMaxDepth(2) .setLossType("logistic") @@ -82,9 +84,19 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext .setSeed(123) .setRawPredictionCol(rawPredictionCol) .setPredictionCol(predictionCol) - val gbtModel = gbt.fit(trainData.toDF()) - val scoredData = gbtModel.transform(validationData.toDF()) - scoredData.select(rawPredictionCol, predictionCol).foreach(row => print(row(0))) + .setLabelCol(labelCol) + .setFeaturesCol(featuresCol) + val gbtModel = gbt.fit(trainData.toDF(labelCol, featuresCol)) + val scoredData = gbtModel.transform(validationData.toDF(labelCol, featuresCol)) + scoredData.select(rawPredictionCol, predictionCol).collect() + .foreach(row => { + val probabilities = Vectors.dense(row(0).asInstanceOf[DenseVector] + .values.map(value => 1 / (1 + math.exp(-value)))) + // Verify probabilities make sense + assert(probabilities.toDense.values.forall(prob => prob <= 1 && prob >= 0)) + // Verify probabilities correspond to labels + assert(probabilities.argmax == row(1)) + }) } test("GBT parameter stepSize should be in interval (0, 1]") { From 27882b347e15d3cdcd07704cc63d5b2b9774f7fe Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 5 Jan 2017 14:41:33 -0500 Subject: [PATCH 06/20] fixed scala style, multiplied raw prediction value by 2 in prob estimate --- .../apache/spark/ml/classification/GBTClassifier.scala | 8 ++++---- .../spark/ml/classification/GBTClassifierSuite.scala | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 40d148b3c166..70ecc44e227b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint @@ -32,7 +33,6 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} -import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -280,14 +280,14 @@ class GBTClassificationModel private[ml]( override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { rawPrediction match { // The probability can be calculated for positive result: - // p+(x) = 1 / (1 + e^(-F(x))) + // p+(x) = 1 / (1 + e^(-2 * F(x))) // and negative result: - // p-(x) = 1 / (1 + e^(F(x))) + // p-(x) = 1 / (1 + e^(2 * F(x))) case dv: DenseVector => var i = 0 val size = dv.size while (i < size) { - dv.values(i) = 1 / (1 + math.exp(-dv.values(i))) + dv.values(i) = 1 / (1 + math.exp(-2 * dv.values(i))) i += 1 } dv diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 08e2a4d94d34..e28bc713b276 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode @@ -28,7 +28,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.util.{MLUtils, MLlibTestSparkContext} +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils @@ -91,7 +91,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext scoredData.select(rawPredictionCol, predictionCol).collect() .foreach(row => { val probabilities = Vectors.dense(row(0).asInstanceOf[DenseVector] - .values.map(value => 1 / (1 + math.exp(-value)))) + .values.map(value => 1 / (1 + math.exp(-2 * value)))) // Verify probabilities make sense assert(probabilities.toDense.values.forall(prob => prob <= 1 && prob >= 0)) // Verify probabilities correspond to labels From 8698d1618d26d255f85bf63875a292cb8e4bb0ab Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 5 Jan 2017 23:47:40 -0500 Subject: [PATCH 07/20] Updating based on code review, including code cleanup and adding better test case --- .../ml/classification/GBTClassifier.scala | 24 ++--- .../classification/GBTClassifierSuite.scala | 98 ++++++++++++++----- 2 files changed, 80 insertions(+), 42 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 70ecc44e227b..ee48b14478b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -225,22 +225,9 @@ class GBTClassificationModel private[ml]( * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. - * @param numFeatures The number of features. */ @Since("1.6.0") - def this(uid: String, _trees: Array[DecisionTreeRegressionModel], - _treeWeights: Array[Double], numFeatures: Int) = - this(uid, _trees, _treeWeights, numFeatures, 2) - - /** - * Construct a GBTClassificationModel - * - * @param _trees Decision trees in the ensemble. - * @param _treeWeights Weights for the decision trees in the ensemble. - */ - @Since("1.6.0") - def this(uid: String, _trees: Array[DecisionTreeRegressionModel], - _treeWeights: Array[Double]) = + def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = this(uid, _trees, _treeWeights, -1, 2) @Since("1.4.0") @@ -287,7 +274,7 @@ class GBTClassificationModel private[ml]( var i = 0 val size = dv.size while (i < size) { - dv.values(i) = 1 / (1 + math.exp(-2 * dv.values(i))) + dv.values(i) = classProbability(getLossType, dv.values(i)) i += 1 } dv @@ -324,6 +311,13 @@ class GBTClassificationModel private[ml]( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + private def classProbability(loss: String, rawPrediction: Double): Double = { + loss match { + case "logistic" => 1 / (1 + math.exp(-2 * rawPrediction)) + case _ => throw new Exception("Only logistic loss is supported ...") + } + } + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e28bc713b276..da6ea7a46216 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -17,20 +17,23 @@ package org.apache.spark.ml.classification +import com.github.fommil.netlib.BLAS + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.{DenseVector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.util.Utils /** @@ -49,6 +52,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext private var data: RDD[LabeledPoint] = _ private var trainData: RDD[LabeledPoint] = _ private var validationData: RDD[LabeledPoint] = _ + private val eps: Double = 1e-5 override def beforeAll() { super.beforeAll() @@ -70,33 +74,73 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(model) } - test("Verify raw scores correspond to labels") { - val rawPredictionCol = "MyRawPrediction" - val predictionCol = "MyPrediction" + test("GBTClassifier: Predictor, Classifier methods") { + val rawPredictionCol = "rawPrediction" + val predictionCol = "prediction" val labelCol = "label" val featuresCol = "features" - val gbt = new GBTClassifier() - .setMaxDepth(2) - .setLossType("logistic") - .setMaxIter(5) - .setStepSize(0.1) - .setCheckpointInterval(2) - .setSeed(123) - .setRawPredictionCol(rawPredictionCol) - .setPredictionCol(predictionCol) - .setLabelCol(labelCol) - .setFeaturesCol(featuresCol) - val gbtModel = gbt.fit(trainData.toDF(labelCol, featuresCol)) - val scoredData = gbtModel.transform(validationData.toDF(labelCol, featuresCol)) - scoredData.select(rawPredictionCol, predictionCol).collect() - .foreach(row => { - val probabilities = Vectors.dense(row(0).asInstanceOf[DenseVector] - .values.map(value => 1 / (1 + math.exp(-2 * value)))) - // Verify probabilities make sense - assert(probabilities.toDense.values.forall(prob => prob <= 1 && prob >= 0)) - // Verify probabilities correspond to labels - assert(probabilities.argmax == row(1)) - }) + val probabilityCol = "probability" + + val gbt = new GBTClassifier().setSeed(123) + val trainingDataset = trainData.toDF(labelCol, featuresCol) + val gbtModel = gbt.fit(trainingDataset) + assert(gbtModel.numClasses === 2) + val numFeatures = trainingDataset.select(featuresCol).first().getAs[Vector](0).size + assert(gbtModel.numFeatures === numFeatures) + + val blas = BLAS.getInstance() + + val validationDataset = validationData.toDF(labelCol, featuresCol) + val results = gbtModel.transform(validationDataset) + // check that raw prediction is tree predictions dot tree weights + results.select(rawPredictionCol, featuresCol).collect().foreach { + case Row(raw: Vector, features: Vector) => + assert(raw.size === 2) + val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction) + val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1) + assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps) + } + + // Compare rawPrediction with probability + results.select(rawPredictionCol, probabilityCol).collect().foreach { + case Row(raw: Vector, prob: Vector) => + assert(raw.size === 2) + assert(prob.size === 2) + val prodFromRaw = raw.toDense.values.map(value => 1 / (1 + math.exp(-2 * value))) + assert(prob(0) ~== prodFromRaw(0) relTol eps) + assert(prob(1) ~== prodFromRaw(1) relTol eps) + } + + // Compare prediction with probability + results.select(predictionCol, probabilityCol).collect().foreach { + case Row(pred: Double, prob: Vector) => + val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2 + assert(pred == predFromProb) + } + + // force it to use raw2prediction + gbtModel.setRawPredictionCol(rawPredictionCol).setProbabilityCol("") + val resultsUsingRaw2Predict = + gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() + resultsUsingRaw2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + gbtModel.setRawPredictionCol("").setProbabilityCol(probabilityCol) + val resultsUsingProb2Predict = + gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() + resultsUsingProb2Predict.zip(results.select(predictionCol).as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use predict + gbtModel.setRawPredictionCol("").setProbabilityCol("") + val resultsUsingPredict = + gbtModel.transform(validationDataset).select(predictionCol).as[Double].collect() + resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } } test("GBT parameter stepSize should be in interval (0, 1]") { From aaf1b068e23430fbcba1b0a181eec60626869791 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Fri, 6 Jan 2017 10:20:53 -0500 Subject: [PATCH 08/20] Adding back constructor but making it private --- .../spark/ml/classification/GBTClassifier.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index ee48b14478b1..5f95b8525192 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -220,6 +220,18 @@ class GBTClassificationModel private[ml]( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTClassificationModel + * + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + * @param numFeatures The number of features. + */ + @Since("1.6.0") + private[ml] def this(uid: String, _trees: Array[DecisionTreeRegressionModel], + _treeWeights: Array[Double], numFeatures: Int) = + this(uid, _trees, _treeWeights, numFeatures, 2) + /** * Construct a GBTClassificationModel * From bafab79fbf5fe535a0b5d5a73d0b2f6cc017e4b3 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 13:31:45 -0500 Subject: [PATCH 09/20] updates to GBTClassifier based on comments --- .../ml/classification/GBTClassifier.scala | 26 +++++------- .../spark/mllib/tree/loss/LogLoss.scala | 15 ++++++- .../classification/GBTClassifierSuite.scala | 40 ++++++++++++++++--- 3 files changed, 59 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 5f95b8525192..16086b5bbe5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,7 +20,6 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ - import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint @@ -32,6 +31,7 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.{ClassificationLoss, LogLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} @@ -227,7 +227,6 @@ class GBTClassificationModel private[ml]( * @param _treeWeights Weights for the decision trees in the ensemble. * @param numFeatures The number of features. */ - @Since("1.6.0") private[ml] def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double], numFeatures: Int) = this(uid, _trees, _treeWeights, numFeatures, 2) @@ -263,16 +262,12 @@ class GBTClassificationModel private[ml]( } override protected def predict(features: Vector): Double = { - // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 - // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + val prediction: Double = margin(features) if (prediction > 0.0) 1.0 else 0.0 } override protected def predictRaw(features: Vector): Vector = { - val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + val prediction: Double = margin(features) Vectors.dense(Array(-prediction, prediction)) } @@ -283,12 +278,8 @@ class GBTClassificationModel private[ml]( // and negative result: // p-(x) = 1 / (1 + e^(2 * F(x))) case dv: DenseVector => - var i = 0 - val size = dv.size - while (i < size) { - dv.values(i) = classProbability(getLossType, dv.values(i)) - i += 1 - } + dv.values(0) = classProbability(getLossType, dv.values(0)) + dv.values(1) = 1.0 - dv.values(0) dv case sv: SparseVector => throw new RuntimeException("Unexpected error in GBTClassificationModel:" + @@ -325,11 +316,16 @@ class GBTClassificationModel private[ml]( private def classProbability(loss: String, rawPrediction: Double): Double = { loss match { - case "logistic" => 1 / (1 + math.exp(-2 * rawPrediction)) + case "logistic" => LogLoss.computeProbability(rawPrediction) case _ => throw new Exception("Only logistic loss is supported ...") } } + private def margin(features: Vector): Double = { + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) + } + /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 5d92ce495b04..7a1fb11983d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -20,6 +20,15 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.util.MLUtils +/** + * :: DeveloperApi :: + * Trait for adding "pluggable" probability function for the gradient boosting algorithm. + */ +@Since("1.2.0") +@DeveloperApi +trait ClassificationLoss extends Loss { + private[spark] def computeProbability(prediction: Double): Double +} /** * :: DeveloperApi :: @@ -32,7 +41,7 @@ import org.apache.spark.mllib.util.MLUtils */ @Since("1.2.0") @DeveloperApi -object LogLoss extends Loss { +object LogLoss extends ClassificationLoss { /** * Method to calculate the loss gradients for the gradient boosting calculation for binary @@ -52,4 +61,8 @@ object LogLoss extends Loss { // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable. 2.0 * MLUtils.log1pExp(-margin) } + + override private[spark] def computeProbability(prediction: Double): Double = { + 1 / (1 + math.exp(-2 * prediction)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index da6ea7a46216..2195dc2e2b70 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS - import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests @@ -31,6 +30,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -53,6 +53,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext private var trainData: RDD[LabeledPoint] = _ private var validationData: RDD[LabeledPoint] = _ private val eps: Double = 1e-5 + private val absEps: Double = 1e-8 override def beforeAll() { super.beforeAll() @@ -74,6 +75,31 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(model) } + test("GBTClassifier: default params") { + val gbt = new GBTClassifier + assert(gbt.getLabelCol === "label") + assert(gbt.getFeaturesCol === "features") + assert(gbt.getPredictionCol === "prediction") + assert(gbt.getRawPredictionCol === "rawPrediction") + assert(gbt.getProbabilityCol === "probability") + val df = trainData.toDF() + val model = gbt.fit(df) + model.transform(df) + .select("label", "probability", "prediction", "rawPrediction") + .collect() + intercept[NoSuchElementException] { + model.getThresholds + } + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getRawPredictionCol === "rawPrediction") + assert(model.getProbabilityCol === "probability") + assert(model.hasParent) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + } + test("GBTClassifier: Predictor, Classifier methods") { val rawPredictionCol = "rawPrediction" val predictionCol = "prediction" @@ -106,9 +132,11 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext case Row(raw: Vector, prob: Vector) => assert(raw.size === 2) assert(prob.size === 2) - val prodFromRaw = raw.toDense.values.map(value => 1 / (1 + math.exp(-2 * value))) - assert(prob(0) ~== prodFromRaw(0) relTol eps) - assert(prob(1) ~== prodFromRaw(1) relTol eps) + // Note: we should check other loss types for classification if they are added + val predFromRaw = raw.toDense.values.map(value => LogLoss.computeProbability(value)) + assert(prob(0) ~== predFromRaw(0) relTol eps) + assert(prob(1) ~== predFromRaw(1) relTol eps) + assert(prob(0) + prob(1) ~== 1.0 absTol absEps) } // Compare prediction with probability From 2a6dea431aee1de17610149dbb19fb7f8d6a8b4a Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 13:34:58 -0500 Subject: [PATCH 10/20] minor fixes to scala style --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 16086b5bbe5e..0dcd5b85012f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ + import org.apache.spark.annotation.Since import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint @@ -31,7 +32,7 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util._ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{ClassificationLoss, LogLoss} +import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} From 52c511569dc674051c8e57491966d000713a5f18 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 13:45:10 -0500 Subject: [PATCH 11/20] Fixing more scala style --- .../apache/spark/ml/classification/GBTClassifierSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 2195dc2e2b70..4bf2d6a1e949 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.param.{ParamMap, ParamsSuite} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests From 609a1b0a29c9835f33196ec56736736e60acf232 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 14:06:19 -0500 Subject: [PATCH 12/20] Using getOldLossType as per comments --- .../apache/spark/ml/classification/GBTClassifier.scala | 9 +-------- .../main/scala/org/apache/spark/ml/tree/treeParams.scala | 4 ++-- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 0dcd5b85012f..0c3ceaf04e04 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -279,7 +279,7 @@ class GBTClassificationModel private[ml]( // and negative result: // p-(x) = 1 / (1 + e^(2 * F(x))) case dv: DenseVector => - dv.values(0) = classProbability(getLossType, dv.values(0)) + dv.values(0) = getOldLossType.computeProbability(dv.values(0)) dv.values(1) = 1.0 - dv.values(0) dv case sv: SparseVector => @@ -315,13 +315,6 @@ class GBTClassificationModel private[ml]( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) - private def classProbability(loss: String, rawPrediction: Double): Double = { - loss match { - case "logistic" => LogLoss.computeProbability(rawPrediction) - case _ => throw new Exception("Only logistic loss is supported ...") - } - } - private def margin(features: Vector): Double = { val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index c7a8f76eca84..5eb707dfe7bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParam def getLossType: String = $(lossType).toLowerCase /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { + override private[ml] def getOldLossType: OldClassificationLoss = { getLossType match { case "logistic" => OldLogLoss case _ => From a28afe60219dba63bfb5df2d32567b7a7ef91594 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 14:46:42 -0500 Subject: [PATCH 13/20] Added more tests for thresholds, fixed minor bug in predict to use thresholds if they are specified --- .../ml/classification/GBTClassifier.scala | 9 +++- .../classification/GBTClassifierSuite.scala | 51 +++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 0c3ceaf04e04..36c69701bab7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -263,8 +263,13 @@ class GBTClassificationModel private[ml]( } override protected def predict(features: Vector): Double = { - val prediction: Double = margin(features) - if (prediction > 0.0) 1.0 else 0.0 + // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization + if (isDefined(thresholds)) { + super.predict(features) + } else { + val prediction: Double = margin(features) + if (prediction > 0.0) 1.0 else 0.0 + } } override protected def predictRaw(features: Vector): Vector = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 4bf2d6a1e949..43e2ef4d2d37 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -101,6 +101,57 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext MLTestingUtils.checkCopy(model) } + test("setThreshold, getThreshold") { + val gbt = new GBTClassifier + + // default + withClue("GBTClassifier should not have thresholds set by default.") { + intercept[NoSuchElementException] { + gbt.getThresholds + } + } + + // Set via thresholds + val gbt2 = new GBTClassifier + val threshold = Array(0.3, 0.7) + gbt2.setThresholds(threshold) + assert(gbt2.getThresholds.zipWithIndex.forall(valueWithIndex => + threshold(valueWithIndex._2) == valueWithIndex._1)) + } + + test("thresholds prediction") { + val gbt = new GBTClassifier + val df = trainData.toDF() + val binaryModel = gbt.fit(df) + + // should predict all zeros + binaryModel.setThresholds(Array(0.0, 1.0)) + val binaryZeroPredictions = binaryModel.transform(df).select("prediction").collect() + assert(binaryZeroPredictions.forall(_.getDouble(0) === 0.0)) + + // should predict all ones + binaryModel.setThresholds(Array(1.0, 0.0)) + val binaryOnePredictions = binaryModel.transform(df).select("prediction").collect() + assert(binaryOnePredictions.forall(_.getDouble(0) === 1.0)) + + + val gbtBase = new GBTClassifier + val model = gbtBase.fit(df) + val basePredictions = model.transform(df).select("prediction").collect() + + // constant threshold scaling is the same as no thresholds + binaryModel.setThresholds(Array(1.0, 1.0)) + val scaledPredictions = binaryModel.transform(df).select("prediction").collect() + assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => + scaled.getDouble(0) === base.getDouble(0) + }) + + // force it to use the predict method + model.setRawPredictionCol("").setProbabilityCol("").setThresholds(Array(0, 1)) + val predictionsWithPredict = model.transform(df).select("prediction").collect() + assert(predictionsWithPredict.forall(_.getDouble(0) === 0.0)) + } + test("GBTClassifier: Predictor, Classifier methods") { val rawPredictionCol = "rawPrediction" val predictionCol = "prediction" From 9d5bb9b598903583c95b4de3142d23106c971e55 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 18:35:43 -0500 Subject: [PATCH 14/20] Updated based on newest comments --- .../ml/classification/GBTClassifier.scala | 24 ++++++++++--------- .../spark/mllib/tree/loss/LogLoss.scala | 11 ++++----- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 36c69701bab7..da51f5aade7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -158,7 +158,7 @@ class GBTClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val numClasses: Int = getNumClasses(dataset) + val numClasses: Int = 2 if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + @@ -229,8 +229,9 @@ class GBTClassificationModel private[ml]( * @param numFeatures The number of features. */ private[ml] def this(uid: String, _trees: Array[DecisionTreeRegressionModel], - _treeWeights: Array[Double], numFeatures: Int) = - this(uid, _trees, _treeWeights, numFeatures, 2) + _treeWeights: Array[Double], + numFeatures: Int) = + this(uid, _trees, _treeWeights, numFeatures, 2) /** * Construct a GBTClassificationModel @@ -240,7 +241,7 @@ class GBTClassificationModel private[ml]( */ @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = - this(uid, _trees, _treeWeights, -1, 2) + this(uid, _trees, _treeWeights, -1, 2) @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees @@ -267,8 +268,7 @@ class GBTClassificationModel private[ml]( if (isDefined(thresholds)) { super.predict(features) } else { - val prediction: Double = margin(features) - if (prediction > 0.0) 1.0 else 0.0 + if (margin(features) > 0.0) 1.0 else 0.0 } } @@ -279,12 +279,8 @@ class GBTClassificationModel private[ml]( override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { rawPrediction match { - // The probability can be calculated for positive result: - // p+(x) = 1 / (1 + e^(-2 * F(x))) - // and negative result: - // p-(x) = 1 / (1 + e^(2 * F(x))) case dv: DenseVector => - dv.values(0) = getOldLossType.computeProbability(dv.values(0)) + dv.values(0) = loss.computeProbability(dv.values(0)) dv.values(1) = 1.0 - dv.values(0) dv case sv: SparseVector => @@ -330,6 +326,12 @@ class GBTClassificationModel private[ml]( new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + /** + * Note: this is currently an optimization that should be removed when we have more loss + * functions available than only logistic. + */ + private lazy val loss = getOldLossType + @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 7a1fb11983d0..da06ef5311ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -21,12 +21,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.util.MLUtils /** - * :: DeveloperApi :: - * Trait for adding "pluggable" probability function for the gradient boosting algorithm. + * Trait for adding probability function for the gradient boosting algorithm. */ -@Since("1.2.0") -@DeveloperApi -trait ClassificationLoss extends Loss { +private[spark] trait ClassificationLoss extends Loss { private[spark] def computeProbability(prediction: Double): Double } @@ -63,6 +60,8 @@ object LogLoss extends ClassificationLoss { } override private[spark] def computeProbability(prediction: Double): Double = { - 1 / (1 + math.exp(-2 * prediction)) + // The probability can be calculated as: + // p+(x) = 1 / (1 + e^(-2 * F(x))) + 1.0 / (1.0 + math.exp(-2.0 * prediction)) } } From 89965f570d3d3bd93e4d1ed9c3d8a5f0c8eb06be Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 18:42:08 -0500 Subject: [PATCH 15/20] missed one arg --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index da51f5aade7a..5d1a3e8b9a3c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -228,7 +228,8 @@ class GBTClassificationModel private[ml]( * @param _treeWeights Weights for the decision trees in the ensemble. * @param numFeatures The number of features. */ - private[ml] def this(uid: String, _trees: Array[DecisionTreeRegressionModel], + private[ml] def this(uid: String, + _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double], numFeatures: Int) = this(uid, _trees, _treeWeights, numFeatures, 2) From cacbbc10849cb0ff62f993921076d96453669bba Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Tue, 10 Jan 2017 18:45:36 -0500 Subject: [PATCH 16/20] Moving arg to its own line --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 5d1a3e8b9a3c..fcf73d471bc0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -228,7 +228,8 @@ class GBTClassificationModel private[ml]( * @param _treeWeights Weights for the decision trees in the ensemble. * @param numFeatures The number of features. */ - private[ml] def this(uid: String, + private[ml] def this( + uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double], numFeatures: Int) = From 7396dac3bab5f91dab4ccb4368bbbed02de53ed8 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 11 Jan 2017 01:29:23 -0500 Subject: [PATCH 17/20] Updated based on latest comments - moved classifier loss trait, updated doc --- .../apache/spark/ml/classification/GBTClassifier.scala | 7 ++----- .../org/apache/spark/mllib/tree/loss/LogLoss.scala | 7 ------- .../scala/org/apache/spark/mllib/tree/loss/Loss.scala | 10 +++++++++- .../spark/ml/classification/GBTClassifierSuite.scala | 3 +-- 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index fcf73d471bc0..874aa5e24355 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -318,6 +318,7 @@ class GBTClassificationModel private[ml]( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + /** Raw prediction for the positive class. */ private def margin(features: Vector): Double = { val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) @@ -328,11 +329,7 @@ class GBTClassificationModel private[ml]( new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } - /** - * Note: this is currently an optimization that should be removed when we have more loss - * functions available than only logistic. - */ - private lazy val loss = getOldLossType + private val loss = getOldLossType @Since("2.0.0") override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index da06ef5311ca..1dc9b072e82a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -20,13 +20,6 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.util.MLUtils -/** - * Trait for adding probability function for the gradient boosting algorithm. - */ -private[spark] trait ClassificationLoss extends Loss { - private[spark] def computeProbability(prediction: Double): Double -} - /** * :: DeveloperApi :: * Class for log loss calculation (for classification). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 09274a2e1b2a..9f87f066583b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD - /** * :: DeveloperApi :: * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. @@ -67,3 +66,12 @@ trait Loss extends Serializable { */ private[spark] def computeError(prediction: Double, label: Double): Double } + +private[spark] trait ClassificationLoss extends Loss { + /** + * Computes the class probability given the margin. + * @param prediction The margin. + * @return The class probability from the margin. + */ + private[spark] def computeProbability(prediction: Double): Double +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 43e2ef4d2d37..b79fbcefef96 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -115,8 +115,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val gbt2 = new GBTClassifier val threshold = Array(0.3, 0.7) gbt2.setThresholds(threshold) - assert(gbt2.getThresholds.zipWithIndex.forall(valueWithIndex => - threshold(valueWithIndex._2) == valueWithIndex._1)) + assert(gbt2.getThresholds.zip(threshold).forall { case(t1, t2) => t1 === t2 }) } test("thresholds prediction") { From f2e041defa64f3444848c6d2ebeaf35e30ed0f03 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 11 Jan 2017 18:40:20 -0500 Subject: [PATCH 18/20] Fixed up minor comments --- .../org/apache/spark/ml/classification/GBTClassifier.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 874aa5e24355..d954ff89eff5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -158,7 +158,7 @@ class GBTClassifier @Since("1.4.0") ( val numFeatures = oldDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) - val numClasses: Int = 2 + val numClasses = 2 if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + @@ -201,8 +201,6 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. - * @param numFeatures The number of features. - * @param numClasses The number of classes. * * @note Multiclass labels are not currently supported. */ @@ -329,6 +327,7 @@ class GBTClassificationModel private[ml]( new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + // hard coded loss, which is not meant to be changed in the model private val loss = getOldLossType @Since("2.0.0") From 1abfee0a0a9bd8a09d38a8877af795bfc8732ccc Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 18 Jan 2017 14:59:23 -0500 Subject: [PATCH 19/20] Updated based on comments from jkbradley --- .../DecisionTreeClassifier.scala | 2 -- .../ml/classification/GBTClassifier.scala | 35 +++++-------------- .../spark/mllib/tree/loss/LogLoss.scala | 9 ++--- .../apache/spark/mllib/tree/loss/Loss.scala | 4 +-- .../classification/GBTClassifierSuite.scala | 2 +- 5 files changed, 15 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 39a959781df2..9f60f0896ec5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -177,8 +177,6 @@ class DecisionTreeClassificationModel private[ml] ( /** * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. - * @param numFeatures The number of features. - * @param numClasses The number of classes to predict. */ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index d954ff89eff5..5794eb26e416 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -170,7 +170,7 @@ class GBTClassifier @Since("1.4.0") ( maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval) instr.logNumFeatures(numFeatures) - instr.logNumClasses(2) + instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) @@ -209,8 +209,7 @@ class GBTClassificationModel private[ml]( @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - @Since("1.6.0") override val numFeatures: Int, - @Since("2.2.0") override val numClasses: Int) + @Since("1.6.0") override val numFeatures: Int) extends ProbabilisticClassificationModel[Vector, GBTClassificationModel] with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { @@ -219,20 +218,6 @@ class GBTClassificationModel private[ml]( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") - /** - * Construct a GBTClassificationModel - * - * @param _trees Decision trees in the ensemble. - * @param _treeWeights Weights for the decision trees in the ensemble. - * @param numFeatures The number of features. - */ - private[ml] def this( - uid: String, - _trees: Array[DecisionTreeRegressionModel], - _treeWeights: Array[Double], - numFeatures: Int) = - this(uid, _trees, _treeWeights, numFeatures, 2) - /** * Construct a GBTClassificationModel * @@ -241,7 +226,7 @@ class GBTClassificationModel private[ml]( */ @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = - this(uid, _trees, _treeWeights, -1, 2) + this(uid, _trees, _treeWeights, -1) @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees @@ -294,7 +279,7 @@ class GBTClassificationModel private[ml]( @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses), + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), extra).setParent(parent) } @@ -339,7 +324,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { private val numFeaturesKey: String = "numFeatures" private val numTreesKey: String = "numTrees" - private val numClassesKey: String = "numClasses" @Since("2.0.0") override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader @@ -354,8 +338,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { val extraMetadata: JObject = Map( numFeaturesKey -> instance.numFeatures, - numTreesKey -> instance.getNumTrees, - numClassesKey -> instance.numClasses) + numTreesKey -> instance.getNumTrees) EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) } } @@ -372,7 +355,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int] val numTrees = (metadata.metadata \ numTreesKey).extract[Int] - val numClasses = (metadata.metadata \ numClassesKey).extract[Int] val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => @@ -384,7 +366,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + s" trees based on metadata but found ${trees.length} trees.") val model = new GBTClassificationModel(metadata.uid, - trees, treeWeights, numFeatures, numClasses) + trees, treeWeights, numFeatures) DefaultParamsReader.getAndSetParams(model, metadata) model } @@ -395,8 +377,7 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], - numFeatures: Int = -1, - numClasses: Int = 2): GBTClassificationModel = { + numFeatures: Int = -1): GBTClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -404,6 +385,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses) + new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 1dc9b072e82a..9339f0a23c1b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -52,9 +52,10 @@ object LogLoss extends ClassificationLoss { 2.0 * MLUtils.log1pExp(-margin) } - override private[spark] def computeProbability(prediction: Double): Double = { - // The probability can be calculated as: - // p+(x) = 1 / (1 + e^(-2 * F(x))) - 1.0 / (1.0 + math.exp(-2.0 * prediction)) + /** + * Returns the estimated probability of a label of 1.0. + */ + override private[spark] def computeProbability(margin: Double): Double = { + 1.0 / (1.0 + math.exp(-2.0 * margin)) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 9f87f066583b..e7ffb3f8f53c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -70,8 +70,6 @@ trait Loss extends Serializable { private[spark] trait ClassificationLoss extends Loss { /** * Computes the class probability given the margin. - * @param prediction The margin. - * @return The class probability from the margin. */ - private[spark] def computeProbability(prediction: Double): Double + private[spark] def computeProbability(margin: Double): Double } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index b79fbcefef96..0598943c3d4b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -115,7 +115,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val gbt2 = new GBTClassifier val threshold = Array(0.3, 0.7) gbt2.setThresholds(threshold) - assert(gbt2.getThresholds.zip(threshold).forall { case(t1, t2) => t1 === t2 }) + assert(gbt2.getThresholds === threshold) } test("thresholds prediction") { From 818de810cbfdf4fc671f21a262a34cc8554f9af6 Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Wed, 18 Jan 2017 16:35:22 -0500 Subject: [PATCH 20/20] Fixing build issues - need to keep numClasses in model --- .../ml/classification/GBTClassifier.scala | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 5794eb26e416..ade0960f87a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -174,7 +174,7 @@ class GBTClassifier @Since("1.4.0") ( val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed)) - val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures, numClasses) + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m } @@ -209,7 +209,8 @@ class GBTClassificationModel private[ml]( @Since("1.6.0") override val uid: String, private val _trees: Array[DecisionTreeRegressionModel], private val _treeWeights: Array[Double], - @Since("1.6.0") override val numFeatures: Int) + @Since("1.6.0") override val numFeatures: Int, + @Since("2.2.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, GBTClassificationModel] with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { @@ -218,6 +219,20 @@ class GBTClassificationModel private[ml]( require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") + /** + * Construct a GBTClassificationModel + * + * @param _trees Decision trees in the ensemble. + * @param _treeWeights Weights for the decision trees in the ensemble. + * @param numFeatures The number of features. + */ + private[ml] def this( + uid: String, + _trees: Array[DecisionTreeRegressionModel], + _treeWeights: Array[Double], + numFeatures: Int) = + this(uid, _trees, _treeWeights, numFeatures, 2) + /** * Construct a GBTClassificationModel * @@ -226,7 +241,7 @@ class GBTClassificationModel private[ml]( */ @Since("1.6.0") def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) = - this(uid, _trees, _treeWeights, -1) + this(uid, _trees, _treeWeights, -1, 2) @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees @@ -279,7 +294,7 @@ class GBTClassificationModel private[ml]( @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses), extra).setParent(parent) } @@ -377,7 +392,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], - numFeatures: Int = -1): GBTClassificationModel = { + numFeatures: Int = -1, + numClasses: Int = 2): GBTClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" + s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -385,6 +401,6 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses) } }