Status: Closed

Changes from all commits (20 commits):
08831e7  [SPARK-14975][ML][WIP] Fixed GBTClassifier to predict probability per… (imatiach-msft, Dec 30, 2016)
e73b60f  Fixed scala style empty line (imatiach-msft, Dec 30, 2016)
d29b70d  Fixed binary compatibility tests (imatiach-msft, Dec 30, 2016)
d4afdd0  Fixing GBT classifier based on comments (imatiach-msft, Jan 3, 2017)
62702c8  Fixing probabilities calculated from raw scores (imatiach-msft, Jan 5, 2017)
27882b3  fixed scala style, multiplied raw prediction value by 2 in prob estimate (imatiach-msft, Jan 5, 2017)
8698d16  Updating based on code review, including code cleanup and adding bett… (imatiach-msft, Jan 6, 2017)
aaf1b06  Adding back constructor but making it private (imatiach-msft, Jan 6, 2017)
bafab79  updates to GBTClassifier based on comments (imatiach-msft, Jan 10, 2017)
2a6dea4  minor fixes to scala style (imatiach-msft, Jan 10, 2017)
52c5115  Fixing more scala style (imatiach-msft, Jan 10, 2017)
609a1b0  Using getOldLossType as per comments (imatiach-msft, Jan 10, 2017)
a28afe6  Added more tests for thresholds, fixed minor bug in predict to use th… (imatiach-msft, Jan 10, 2017)
9d5bb9b  Updated based on newest comments (imatiach-msft, Jan 10, 2017)
89965f5  missed one arg (imatiach-msft, Jan 10, 2017)
cacbbc1  Moving arg to its own line (imatiach-msft, Jan 10, 2017)
7396dac  Updated based on latest comments - moved classifier loss trait, updat… (imatiach-msft, Jan 11, 2017)
f2e041d  Fixed up minor comments (imatiach-msft, Jan 11, 2017)
1abfee0  Updated based on comments from jkbradley (imatiach-msft, Jan 18, 2017)
818de81  Fixing build issues - need to keep numClasses in model (imatiach-msft, Jan 18, 2017)
mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -23,16 +23,16 @@ import org.json4s.JsonDSL._

 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util._
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.LogLoss
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -58,7 +58,7 @@ import org.apache.spark.sql.functions._
 @Since("1.4.0")
 class GBTClassifier @Since("1.4.0") (
     @Since("1.4.0") override val uid: String)
-  extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
+  extends ProbabilisticClassifier[Vector, GBTClassifier, GBTClassificationModel]
   with GBTClassifierParams with DefaultParamsWritable with Logging {

   @Since("1.4.0")
@@ -158,12 +158,19 @@ class GBTClassifier @Since("1.4.0") (
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)

+    val numClasses = 2
+    if (isDefined(thresholds)) {
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
+    }
+
     val instr = Instrumentation.create(this, oldDataset)
     instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
       seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval)
     instr.logNumFeatures(numFeatures)
-    instr.logNumClasses(2)
+    instr.logNumClasses(numClasses)

     val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
       $(seed))
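
GBTs are binary-only here, so numClasses is hard-coded to 2 and thresholds, if set, must have exactly two entries. A sketch of what the new require enforces (values are illustrative):

    // Accepted: one threshold per class (numClasses = 2).
    val ok = new GBTClassifier().setThresholds(Array(0.3, 0.7))

    // Rejected at train() time by the require above (three thresholds, two classes):
    // new GBTClassifier().setThresholds(Array(0.3, 0.3, 0.4)).fit(training)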
@@ -202,15 +209,30 @@ class GBTClassificationModel private[ml](
     @Since("1.6.0") override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
     private val _treeWeights: Array[Double],
-    @Since("1.6.0") override val numFeatures: Int)
-  extends PredictionModel[Vector, GBTClassificationModel]
+    @Since("1.6.0") override val numFeatures: Int,
+    @Since("2.2.0") override val numClasses: Int)
+  extends ProbabilisticClassificationModel[Vector, GBTClassificationModel]
   with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel]
   with MLWritable with Serializable {

   require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
   require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
     s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")

+  /**
+   * Construct a GBTClassificationModel
+   *
+   * @param _trees  Decision trees in the ensemble.
+   * @param _treeWeights  Weights for the decision trees in the ensemble.
+   * @param numFeatures  The number of features.
+   */
+  private[ml] def this(
+      uid: String,
+      _trees: Array[DecisionTreeRegressionModel],
+      _treeWeights: Array[Double],
+      numFeatures: Int) =
+    this(uid, _trees, _treeWeights, numFeatures, 2)
+
   /**
    * Construct a GBTClassificationModel
    *
@@ -219,7 +241,7 @@ class GBTClassificationModel private[ml](
    */
   @Since("1.6.0")
   def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
-    this(uid, _trees, _treeWeights, -1)
+    this(uid, _trees, _treeWeights, -1, 2)

   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
@@ -242,19 +264,37 @@ class GBTClassificationModel private[ml](
   }

   override protected def predict(features: Vector): Double = {
-    // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
-    // Classifies by thresholding sum of weighted tree predictions
-    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
-    if (prediction > 0.0) 1.0 else 0.0
+    // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
+    if (isDefined(thresholds)) {
+      super.predict(features)
+    } else {
+      if (margin(features) > 0.0) 1.0 else 0.0
+    }
+  }
+
+  override protected def predictRaw(features: Vector): Vector = {
+    val prediction: Double = margin(features)
+    Vectors.dense(Array(-prediction, prediction))
+  }
+
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        dv.values(0) = loss.computeProbability(dv.values(0))
+        dv.values(1) = 1.0 - dv.values(0)
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in GBTClassificationModel:" +
+          " raw2probabilityInPlace encountered SparseVector")
+    }
   }

   /** Number of trees in ensemble */
   val numTrees: Int = trees.length

   @Since("1.4.0")
   override def copy(extra: ParamMap): GBTClassificationModel = {
-    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures, numClasses),
       extra).setParent(parent)
   }
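
With the logistic loss, predictRaw returns (-m, m) for ensemble margin m, and raw2probabilityInPlace maps that to P(y=0) = 1 / (1 + e^(2m)) and P(y=1) = 1 / (1 + e^(-2m)). When thresholds is unset, predict skips the probability computation and thresholds the margin at zero, which selects the same class. A standalone sketch of the mapping (plain Scala, not part of the patch):

    // sigmoid2 matches LogLoss.computeProbability below.
    def sigmoid2(x: Double): Double = 1.0 / (1.0 + math.exp(-2.0 * x))

    val m = 1.25                        // hypothetical margin
    val raw = Array(-m, m)              // what predictRaw returns
    val p0 = sigmoid2(raw(0))           // P(y = 0), as in raw2probabilityInPlace
    val probs = Array(p0, 1.0 - p0)
    // probs(1) == sigmoid2(m): a positive margin favors class 1.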

@@ -276,18 +316,30 @@ class GBTClassificationModel private[ml](
   @Since("2.0.0")
   lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)

+  /** Raw prediction for the positive class. */
+  private def margin(features: Vector): Double = {
+    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
+    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+  }
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
   }

+  // hard coded loss, which is not meant to be changed in the model
+  private val loss = getOldLossType
+
   @Since("2.0.0")
   override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
 }

 @Since("2.0.0")
 object GBTClassificationModel extends MLReadable[GBTClassificationModel] {

+  private val numFeaturesKey: String = "numFeatures"
+  private val numTreesKey: String = "numTrees"
+
   @Since("2.0.0")
   override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader
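
The new margin helper centralizes the ensemble score that predict and predictRaw both rely on; blas.ddot is simply a dot product of per-tree predictions and tree weights. An equivalent plain-Scala formulation (illustrative only):

    // Equivalent to blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1):
    def marginOf(treePredictions: Array[Double], treeWeights: Array[Double]): Double =
      treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum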

@@ -300,8 +352,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
     override protected def saveImpl(path: String): Unit = {

       val extraMetadata: JObject = Map(
-        "numFeatures" -> instance.numFeatures,
-        "numTrees" -> instance.getNumTrees)
+        numFeaturesKey -> instance.numFeatures,
+        numTreesKey -> instance.getNumTrees)
       EnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata)
     }
   }
@@ -316,8 +368,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       implicit val format = DefaultFormats
       val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
         EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
-      val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
-      val numTrees = (metadata.metadata \ "numTrees").extract[Int]
+      val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
+      val numTrees = (metadata.metadata \ numTreesKey).extract[Int]

       val trees: Array[DecisionTreeRegressionModel] = treesData.map {
         case (treeMetadata, root) =>
@@ -328,7 +380,8 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       }
       require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
         s" trees based on metadata but found ${trees.length} trees.")
-      val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures)
+      val model = new GBTClassificationModel(metadata.uid,
+        trees, treeWeights, numFeatures)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
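
Because the reader builds the model through the private four-argument constructor, deserialized models pick up the numClasses = 2 default. A round-trip sketch (the path is a placeholder):

    model.write.overwrite().save("/tmp/gbtModel")
    val loaded = GBTClassificationModel.load("/tmp/gbtModel")
    assert(loaded.numClasses == 2)   // always 2 for GBTs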
@@ -339,14 +392,15 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
       oldModel: OldGBTModel,
       parent: GBTClassifier,
       categoricalFeatures: Map[Int, Int],
-      numFeatures: Int = -1): GBTClassificationModel = {
+      numFeatures: Int = -1,
+      numClasses: Int = 2): GBTClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
       // parent for each tree is null since there is no good way to set this.
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
-    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures)
+    new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures, numClasses)
   }
 }
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
+import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, ClassificationLoss => OldClassificationLoss, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError}
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

 /**
@@ -531,7 +531,7 @@ private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams {
   def getLossType: String = $(lossType).toLowerCase

   /** (private[ml]) Convert new loss to old loss. */
-  override private[ml] def getOldLossType: OldLoss = {
+  override private[ml] def getOldLossType: OldClassificationLoss = {
     getLossType match {
       case "logistic" => OldLogLoss
       case _ =>
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.mllib.util.MLUtils

-
 /**
  * :: DeveloperApi ::
  * Class for log loss calculation (for classification).
@@ -32,7 +31,7 @@ import org.apache.spark.mllib.util.MLUtils
  */
 @Since("1.2.0")
 @DeveloperApi
-object LogLoss extends Loss {
+object LogLoss extends ClassificationLoss {

   /**
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
@@ -52,4 +51,11 @@ object LogLoss extends Loss {
     // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
     2.0 * MLUtils.log1pExp(-margin)
   }
+
+  /**
+   * Returns the estimated probability of a label of 1.0.
+   */
+  override private[spark] def computeProbability(margin: Double): Double = {
+    1.0 / (1.0 + math.exp(-2.0 * margin))
+  }
 }
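
computeProbability is the inverse link for this loss: computeError is written in terms of margin = 2 * label * prediction, so the class-1 probability is the logistic of twice the raw prediction. A sanity-check sketch (computeProbability is private[spark], so this only compiles inside Spark's own sources, e.g. a test suite):

    // Probabilities for opposite margins are complementary:
    // 1/(1 + e^(2m)) + 1/(1 + e^(-2m)) == 1 for every m.
    for (m <- Seq(-3.0, -0.5, 0.0, 0.5, 3.0)) {
      val p1 = LogLoss.computeProbability(m)
      val p0 = LogLoss.computeProbability(-m)
      assert(math.abs(p0 + p1 - 1.0) < 1e-12)
    }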
mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,7 +22,6 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
 import org.apache.spark.rdd.RDD

-
 /**
  * :: DeveloperApi ::
  * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -67,3 +66,10 @@ trait Loss extends Serializable {
    */
   private[spark] def computeError(prediction: Double, label: Double): Double
 }
+
+private[spark] trait ClassificationLoss extends Loss {
+  /**
+   * Computes the class probability given the margin.
+   */
+  private[spark] def computeProbability(margin: Double): Double
+}
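
The trait is deliberately narrow: the regression losses (SquaredError, AbsoluteError) are untouched, and only losses that can turn a margin into a probability opt in. A sketch of how a caller consumes it (both names are private[spark], so this likewise only compiles inside Spark's packages):

    import org.apache.spark.mllib.tree.loss.{ClassificationLoss, LogLoss}

    // Any ClassificationLoss can map a raw margin to P(label = 1):
    val loss: ClassificationLoss = LogLoss
    val pPositive = loss.computeProbability(0.4)   // ~= 0.69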