mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -21,7 +21,6 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._

import org.apache.spark.annotation.Since
- import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -33,7 +32,6 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
- import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType
@@ -69,6 +67,10 @@ class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0")
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -118,6 +120,16 @@
def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)

+ /**
+  * Sets the value of param [[weightCol]].
+  * If this is not set or empty, we treat all instance weights as 1.0.
+  * By default the weightCol is not set, so all instances have weight 1.0.
+  *
+  * @group setParam
+  */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)

override protected def train(
dataset: Dataset[_]): RandomForestClassificationModel = instrumented { instr =>
instr.logPipelineStage(this)
@@ -132,14 +144,14 @@
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}

- val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
+ val instances = extractInstances(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)

- instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
-   leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB,
-   minInfoGain, minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds,
-   checkpointInterval)
+ instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, probabilityCol,
+   rawPredictionCol, leafCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins,
+   maxMemoryInMB, minInfoGain, minInstancesPerNode, minWeightFractionPerNode, seed,
+   subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)

val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
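Taken together, the two setters added above give RandomForestClassifier the same weighting surface as Spark's other weighted estimators. A minimal usage sketch, assuming a DataFrame `train` with "features", "label", and a per-row "weight" column (the data and column names are illustrative, not part of this patch):

import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier()
  .setNumTrees(20)
  .setWeightCol("weight")             // heavier rows count more in impurity and leaf statistics
  .setMinWeightFractionPerNode(0.05)  // discard splits leaving a child under 5% of the total weight
val model = rf.fit(train)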
mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -64,6 +64,10 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("1.4.0")
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)

+ /** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)

/** @group setParam */
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -113,20 +117,31 @@
def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)

+ /**
+  * Sets the value of param [[weightCol]].
+  * If this is not set or empty, we treat all instance weights as 1.0.
+  * By default the weightCol is not set, so all instances have weight 1.0.
+  *
+  * @group setParam
+  */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)

override protected def train(
dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))

- val instances = extractLabeledPoints(dataset).map(_.toInstance)
+ val instances = extractInstances(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)

instr.logPipelineStage(this)
instr.logDataset(instances)
- instr.logParams(this, labelCol, featuresCol, predictionCol, leafCol, impurity, numTrees,
-   featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
-   minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
+ instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, leafCol, impurity,
+   numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
+   minInstancesPerNode, minWeightFractionPerNode, seed, subsamplingRate, cacheNodeIds,
+   checkpointInterval)

val trees = RandomForest
.run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
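The regressor gains the identical pair of setters. As a reminder of the semantics, minWeightFractionPerNode is the minimum fraction of the total weighted sample count that each child must retain after a split; a split that starves either child is rejected as invalid. A toy check under that reading (a hypothetical helper, not the actual Spark internals):

// Hypothetical validity check: both children must keep at least
// minWeightFractionPerNode of the training set's total instance weight.
def splitIsValid(
    leftChildWeight: Double,
    rightChildWeight: Double,
    totalWeight: Double,
    minWeightFractionPerNode: Double): Boolean = {
  val minChildWeight = minWeightFractionPerNode * totalWeight
  leftChildWeight >= minChildWeight && rightChildWeight >= minChildWeight
}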
mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -65,7 +65,8 @@ private[spark] object BaggedPoint {
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
// TODO: implement weighted bootstrapping
if (withReplacement) {
- convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+ convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples,
+   extractSampleWeight, seed)
} else {
if (numSubsamples == 1 && subsamplingRate == 1.0) {
convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
@@ -104,6 +105,7 @@
input: RDD[Datum],
subsample: Double,
numSubsamples: Int,
+ extractSampleWeight: (Datum => Double),
seed: Long): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
@@ -116,7 +118,7 @@
subsampleCounts(subsampleIndex) = poisson.sample()
subsampleIndex += 1
}
- new BaggedPoint(instance, subsampleCounts)
+ new BaggedPoint(instance, subsampleCounts, extractSampleWeight(instance))
}
}
}
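For context: sampling with replacement is approximated per partition by drawing a Poisson(subsamplingRate) replication count for every (instance, subsample) pair, and the fix above simply carries each instance's weight into the resulting BaggedPoint instead of defaulting it to 1.0. A self-contained sketch of the idea, simplified from the real class (the names here are illustrative):

import scala.util.Random

// Simplified stand-in for Spark's BaggedPoint: per-subsample replication counts plus a sample weight.
case class Bagged[D](datum: D, subsampleCounts: Array[Int], sampleWeight: Double)

// Knuth's Poisson sampler; adequate for the small means used in subsampling.
def poisson(mean: Double, rng: Random): Int = {
  val limit = math.exp(-mean)
  var k = 0
  var p = 1.0
  do { k += 1; p *= rng.nextDouble() } while (p > limit)
  k - 1
}

def bagWithReplacement[D](
    data: Seq[D],
    subsamplingRate: Double,
    numSubsamples: Int,
    extractSampleWeight: D => Double,
    seed: Long): Seq[Bagged[D]] = {
  val rng = new Random(seed)
  data.map { d =>
    // Each subsample independently decides how many times this datum is replicated.
    val counts = Array.fill(numSubsamples)(poisson(subsamplingRate, rng))
    Bagged(d, counts, extractSampleWeight(d))  // the weight now flows through, mirroring the patch
  }
}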
mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.classification

import org.apache.spark.SparkFunSuite
+ import org.apache.spark.ml.classification.LinearSVCSuite.generateSVMInput
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
@@ -41,6 +42,8 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {

private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _
+ private var binaryDataset: DataFrame = _
+ private val seed = 42

override def beforeAll(): Unit = {
super.beforeAll()
@@ -50,6 +53,7 @@
orderedLabeledPoints5_20 =
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20))
.map(_.asML)
+ binaryDataset = generateSVMInput(0.01, Array[Double](-1.5, 1.0), 1000, seed).toDF()
}

/////////////////////////////////////////////////////////////////////////////
@@ -259,6 +263,37 @@
})
}

test("training with sample weights") {
val df = binaryDataset
val numClasses = 2
// (numTrees, maxDepth, subsamplingRate, fractionInTol)
val testParams = Seq(
(20, 5, 1.0, 0.96),
(20, 10, 1.0, 0.96),
(20, 10, 0.95, 0.96)
)
Contributor: I guess maybe also add different impurities in testParams?

Contributor: Maybe also test the special case numTrees = 1?

Author: With numTrees == 1, RF is exactly the DecisionTree, which is already tested in DecisionTreeClassifierSuite/DecisionTreeRegressorSuite.

> I guess maybe also add different impurities in testParams?

I guess the current tests are enough; the test suites for DT/GBT do not test impurity.

Contributor: The reason I suggested testing different impurities is that, when calculating the best split, the impurity path (both entropy and gini) is affected by sample weight. However, after taking a look at the DecisionTree test, I saw that both entropy and gini are tested with sample weight there, so this is already covered by the DecisionTree test; no need to test it here.


+ for ((numTrees, maxDepth, subsamplingRate, tol) <- testParams) {
+ val estimator = new RandomForestClassifier()
+ .setNumTrees(numTrees)
+ .setMaxDepth(maxDepth)
+ .setSubsamplingRate(subsamplingRate)
+ .setSeed(seed)
+ .setMinWeightFractionPerNode(0.049)
+
+ MLTestingUtils.testArbitrarilyScaledWeights[RandomForestClassificationModel,
+ RandomForestClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, _ == _, tol))
+ MLTestingUtils.testOutliersWithSmallWeights[RandomForestClassificationModel,
+ RandomForestClassifier](df.as[LabeledPoint], estimator,
+ numClasses, MLTestingUtils.modelPredictionEquals(df, _ == _, tol),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[RandomForestClassificationModel,
+ RandomForestClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, _ == _, tol), seed)
+ }
+ }

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
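These MLTestingUtils helpers check weight semantics rather than exact trees: rescaling every weight by a common constant should leave predictions unchanged, heavily down-weighted outliers should not move the model, and replicating a row k times should behave like giving it weight k. A hedged sketch of the first invariant (`weightedDf` and `testDf` are assumed fixtures, and the real helper tolerates a small fraction of mismatched predictions rather than requiring all of them to agree):

import org.apache.spark.sql.functions.col

// Illustrative scaled-weights check: fitting with weights w and with 1000 * w
// should produce (nearly) identical predictions.
val est = new RandomForestClassifier().setWeightCol("weight").setSeed(42L)
val m1 = est.fit(weightedDf)
val m2 = est.fit(weightedDf.withColumn("weight", col("weight") * 1000.0))
val predictions = m1.transform(testDf).select("prediction").collect()
  .zip(m2.transform(testDf).select("prediction").collect())
assert(predictions.forall { case (a, b) => a.getDouble(0) == b.getDouble(0) })

The regressor suite below mirrors this test with approximate prediction equality (relTol 0.2) in place of exact matches.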
mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -22,9 +22,11 @@ import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
+ import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}

@@ -37,12 +39,18 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
import testImplicits._

private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
+ private var linearRegressionData: DataFrame = _
+ private val seed = 42

override def beforeAll(): Unit = {
super.beforeAll()
orderedLabeledPoints50_1000 =
sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
.map(_.asML))

+ linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+ xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF()
}

/////////////////////////////////////////////////////////////////////////////
@@ -158,6 +166,37 @@
})
}

test("training with sample weights") {
val df = linearRegressionData
val numClasses = 0
// (numTrees, maxDepth, subsamplingRate, fractionInTol)
val testParams = Seq(
(50, 5, 1.0, 0.75),
(50, 10, 1.0, 0.75),
(50, 10, 0.95, 0.78)
)

for ((numTrees, maxDepth, subsamplingRate, tol) <- testParams) {
val estimator = new RandomForestRegressor()
.setNumTrees(numTrees)
.setMaxDepth(maxDepth)
.setSubsamplingRate(subsamplingRate)
.setSeed(seed)
.setMinWeightFractionPerNode(0.05)

MLTestingUtils.testArbitrarilyScaledWeights[RandomForestRegressionModel,
RandomForestRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.2, tol))
MLTestingUtils.testOutliersWithSmallWeights[RandomForestRegressionModel,
RandomForestRegressor](df.as[LabeledPoint], estimator,
numClasses, MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.2, tol),
outlierRatio = 2)
MLTestingUtils.testOversamplingVsWeighting[RandomForestRegressionModel,
RandomForestRegressor](df.as[LabeledPoint], estimator,
MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.2, tol), seed)
}
}

/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -54,8 +54,7 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
- // should ignore weight function for now
- assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
+ assert(baggedRDD.collect().forall(_.sampleWeight === 2.0))
}
}

24 changes: 20 additions & 4 deletions python/pyspark/ml/classification.py
@@ -1387,6 +1387,8 @@ class RandomForestClassifier(JavaProbabilisticClassifier, _RandomForestClassifie
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42,
... leafCol="leafId")
+ >>> rf.getMinWeightFractionPerNode()
+ 0.0
>>> model = rf.fit(td)
>>> model.getLabelCol()
'indexed'
@@ -1441,14 +1443,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0,
leafCol="", minWeightFractionPerNode=0.0):
leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
numTrees=20, featureSubsetStrategy="auto", seed=None, subsamplingRate=1.0, \
leafCol="", minWeightFractionPerNode=0.0)
leafCol="", minWeightFractionPerNode=0.0, weightCol=None)
"""
super(RandomForestClassifier, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1467,14 +1469,14 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0,
leafCol="", minWeightFractionPerNode=0.0):
leafCol="", minWeightFractionPerNode=0.0, weightCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0, \
leafCol="", minWeightFractionPerNode=0.0)
leafCol="", minWeightFractionPerNode=0.0, weightCol=None)
Sets params for linear classification.
"""
kwargs = self._input_kwargs
@@ -1559,6 +1561,20 @@ def setCheckpointInterval(self, value):
"""
return self._set(checkpointInterval=value)

@since("3.0.0")
def setWeightCol(self, value):
"""
Sets the value of :py:attr:`weightCol`.
"""
return self._set(weightCol=value)

@since("3.0.0")
def setMinWeightFractionPerNode(self, value):
"""
Sets the value of :py:attr:`minWeightFractionPerNode`.
"""
return self._set(minWeightFractionPerNode=value)


class RandomForestClassificationModel(_TreeEnsembleModel, JavaProbabilisticClassificationModel,
_RandomForestClassifierParams, JavaMLWritable,