diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 6c79d77f142e..fb98bcb9d6b0 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -52,7 +52,7 @@ object TestingUtils {
/**
* Private helper function for comparing two values using absolute tolerance.
*/
- private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
// Special case for NaNs
if (x.isNaN && y.isNaN) {
return true
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 7e5790ab70ee..e35e6ce7fdad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -77,17 +77,37 @@ abstract class Classifier[
* @note Throws `SparkException` if any label is a non-integer or is negative
*/
protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = {
- require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
- s" $numClasses, but requires numClasses > 0.")
+ validateNumClasses(numClasses)
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
- require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
- s" dataset with invalid label $label. Labels must be integers in range" +
- s" [0, $numClasses).")
+ validateLabel(label, numClasses)
LabeledPoint(label, features)
}
}
+ /**
+ * Validates that the number of classes is greater than zero.
+ *
+ * @param numClasses Number of classes the label can take.
+ */
+ protected def validateNumClasses(numClasses: Int): Unit = {
+ require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" +
+ s" $numClasses, but requires numClasses > 0.")
+ }
+
+ /**
+ * Validates that the label is an integer in the range [0, numClasses).
+ *
+ * @param label The label to validate.
+ * @param numClasses Number of classes the label can take. Labels must be integers in the range
+ * [0, numClasses).
+ */
+ protected def validateLabel(label: Double, numClasses: Int): Unit = {
+ require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" +
+ s" dataset with invalid label $label. Labels must be integers in range" +
+ s" [0, $numClasses).")
+ }
+
/**
* Get the number of classes. This looks in column metadata first, and if that is missing,
* then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses
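
As an aside, a minimal standalone sketch of the check that validateLabel now performs; the object name and the sample values below are mine, not part of the change:

object LabelValidationSketch {
  // Mirrors the check introduced above: the label must be an integral value in [0, numClasses).
  def validateLabel(label: Double, numClasses: Int): Unit = {
    require(label.toLong == label && label >= 0 && label < numClasses,
      s"Classifier was given dataset with invalid label $label. Labels must be integers in" +
      s" range [0, $numClasses).")
  }

  def main(args: Array[String]): Unit = {
    validateLabel(2.0, numClasses = 3)    // passes
    // validateLabel(2.5, numClasses = 3) // would throw: label is not an integer
    // validateLabel(3.0, numClasses = 3) // would throw: label must be < numClasses
  }
}
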
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index d9292a547676..200ac0032e51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,10 +22,12 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
@@ -33,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
-
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -97,6 +103,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override protected def train(
dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr =>
instr.logPipelineStage(this)
@@ -104,22 +120,27 @@ class DecisionTreeClassifier @Since("1.4.0") (
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
- instr.logNumClasses(numClasses)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
-
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+ validateNumClasses(numClasses)
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ validateLabel(label, numClasses)
+ Instance(label, weight, features)
+ }
val strategy = getOldStrategy(categoricalFeatures, numClasses)
-
+ instr.logNumClasses(numClasses)
instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol,
probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)
- val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeClassificationModel]
@@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") (
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr =>
+ val instances = data.map(_.toInstance)
instr.logPipelineStage(this)
- instr.logDataset(data)
+ instr.logDataset(instances)
instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB,
cacheNodeIds, checkpointInterval, impurity, seed)
-
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 0L, instr = Some(instr), parentUID = Some(uid))
+ val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+ featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeClassificationModel]
}
@@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] (
/**
* Construct a decision tree classification model.
+ *
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
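
For illustration, a hedged end-to-end sketch of how the new setters would be used on DecisionTreeClassifier. The DataFrame, column names, and values are made up; the two marked setters are the ones introduced by this diff:

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object WeightedDecisionTreeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("weighted-dt").getOrCreate()
    import spark.implicits._

    // Toy data: (label, weight, features); column names are illustrative.
    val df = Seq(
      (0.0, 1.0, Vectors.dense(0.0, 1.0)),
      (1.0, 3.0, Vectors.dense(1.0, 0.0)),
      (1.0, 0.5, Vectors.dense(0.9, 0.1))
    ).toDF("label", "weight", "features")

    val dt = new DecisionTreeClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setWeightCol("weight")              // added by this change
      .setMinWeightFractionPerNode(0.05)   // added by this change

    val model = dt.fit(df)
    println(model.toDebugString)
    spark.stop()
  }
}
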
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 0a3bfd1f85e0..3500f2ad52a5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -21,20 +21,21 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
/**
* Random Forest learning algorithm for
@@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+ val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
@@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") (
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
val trees = RandomForest
- .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
val numFeatures = trees.head.numFeatures
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
index 412954f7b2d5..f6667b73304a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
@@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
override def toString: String = {
s"($label,$features)"
}
+
+ private[spark] def toInstance(weight: Double): Instance = {
+ Instance(label, weight, features)
+ }
+
+ private[spark] def toInstance: Instance = {
+ Instance(label, 1.0, features)
+ }
+
}
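
A quick sketch of what the new helpers produce. Note that both Instance and the toInstance methods are private[spark], so this only compiles from within Spark's own packages:

import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vectors

val lp = LabeledPoint(1.0, Vectors.dense(0.5, 2.0))
val unweighted: Instance = lp.toInstance       // weight defaults to 1.0
val weighted: Instance = lp.toInstance(3.0)    // explicit sample weight attached
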
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index faadc4d7b4cc..525479151aa0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,9 +23,10 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
@@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
/**
@@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
+ @Since("3.0.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
@Since("1.4.0")
def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("2.0.0")
def setVarianceCol(value: String): this.type = set(varianceCol, value)
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override protected def train(
dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
val strategy = getOldStrategy(categoricalFeatures)
instr.logPipelineStage(this)
- instr.logDataset(oldDataset)
+ instr.logDataset(instances)
instr.logParams(this, params: _*)
- val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
@@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
instr.logDataset(data)
instr.logParams(this, params: _*)
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy,
- seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+ val instances = data.map(_.toInstance)
+ val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+ featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
@@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
*
* Decision tree (Wikipedia) model for regression.
* It supports both continuous and categorical features.
+ *
* @param rootNode Root of the decision tree
*/
@Since("1.4.0")
@@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] (
/**
* Construct a decision tree regression model.
+ *
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int) =
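
The train method above follows the usual Spark ML pattern for optional weight columns: substitute lit(1.0) when weightCol is unset or empty. A standalone sketch of that pattern, with hypothetical names of my own choosing:

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

object WeightExtractionSketch {
  // Falls back to a constant weight of 1.0 when no weight column is configured,
  // mirroring the `if (!isDefined(weightCol) ...) lit(1.0) else col(...)` pattern above.
  def extractWeighted(
      df: DataFrame,
      labelCol: String,
      featuresCol: String,
      weightCol: Option[String]): RDD[(Double, Double, Vector)] = {
    val w = weightCol.filter(_.nonEmpty).map(col).getOrElse(lit(1.0))
    df.select(col(labelCol).cast(DoubleType), w, col(featuresCol)).rdd.map {
      case Row(label: Double, weight: Double, features: Vector) => (label, weight, features)
    }
  }
}
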
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index afa9a646412b..6f36dfb9ff51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
/**
* Random Forest
@@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr =>
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+
+ val instances = extractLabeledPoints(dataset).map(_.toInstance)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
instr.logPipelineStage(this)
- instr.logDataset(dataset)
+ instr.logDataset(instances)
instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees,
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
val trees = RandomForest
- .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
val numFeatures = trees.head.numFeatures
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 4e372702f0c6..c896b1589a93 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -33,13 +33,13 @@ import org.apache.spark.util.random.XORShiftRandom
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
*
* @param datum Data instance
- * @param subsampleWeights Weight of this instance in each subsampled dataset.
- *
- * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
- * dataset support, update. (We store subsampleWeights as Double for this future extension.)
+ * @param subsampleCounts Number of samples of this instance in each subsampled dataset.
+ * @param sampleWeight The weight of this instance.
*/
-private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
- extends Serializable
+private[spark] class BaggedPoint[Datum](
+ val datum: Datum,
+ val subsampleCounts: Array[Int],
+ val sampleWeight: Double = 1.0) extends Serializable
private[spark] object BaggedPoint {
@@ -52,6 +52,7 @@ private[spark] object BaggedPoint {
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param numSubsamples Number of subsamples of this RDD to take.
* @param withReplacement Sampling with/without replacement.
+ * @param extractSampleWeight A function to get the sample weight of each datum.
* @param seed Random seed.
* @return BaggedPoint dataset representation.
*/
@@ -60,12 +61,14 @@ private[spark] object BaggedPoint {
subsamplingRate: Double,
numSubsamples: Int,
withReplacement: Boolean,
+ extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
+ // TODO: implement weighted bootstrapping
if (withReplacement) {
convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
} else {
if (numSubsamples == 1 && subsamplingRate == 1.0) {
- convertToBaggedRDDWithoutSampling(input)
+ convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
} else {
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
}
@@ -82,16 +85,15 @@ private[spark] object BaggedPoint {
val rng = new XORShiftRandom
rng.setSeed(seed + partitionIndex + 1)
instances.map { instance =>
- val subsampleWeights = new Array[Double](numSubsamples)
+ val subsampleCounts = new Array[Int](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
- val x = rng.nextDouble()
- subsampleWeights(subsampleIndex) = {
- if (x < subsamplingRate) 1.0 else 0.0
+ if (rng.nextDouble() < subsamplingRate) {
+ subsampleCounts(subsampleIndex) = 1
}
subsampleIndex += 1
}
- new BaggedPoint(instance, subsampleWeights)
+ new BaggedPoint(instance, subsampleCounts)
}
}
}
@@ -106,20 +108,20 @@ private[spark] object BaggedPoint {
val poisson = new PoissonDistribution(subsample)
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
instances.map { instance =>
- val subsampleWeights = new Array[Double](numSubsamples)
+ val subsampleCounts = new Array[Int](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
- subsampleWeights(subsampleIndex) = poisson.sample()
+ subsampleCounts(subsampleIndex) = poisson.sample()
subsampleIndex += 1
}
- new BaggedPoint(instance, subsampleWeights)
+ new BaggedPoint(instance, subsampleCounts)
}
}
}
private def convertToBaggedRDDWithoutSampling[Datum] (
- input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
- input.map(datum => new BaggedPoint(datum, Array(1.0)))
+ input: RDD[Datum],
+ extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
+ input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum)))
}
-
}
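
Conceptually, the split representation above means a point's effective weight in tree t is its replication count for that subsample times its sample weight. A tiny illustration with made-up numbers:

// Conceptual only: a data point's contribution to tree t's sufficient statistics is
// its replication count in that subsample times its sample weight.
val subsampleCounts = Array(1, 0, 4)                          // copies in each of 3 subsamples
val sampleWeight = 2.5                                        // weight from the weight column
val effectiveWeights = subsampleCounts.map(_ * sampleWeight)  // Array(2.5, 0.0, 10.0)
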
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 5aeea1443d49..17ec161f2f50 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator(
/**
* Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
- def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+ def update(
+ featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
val i = featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
+ impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
}
/**
* Update the parent node stats using the given label.
*/
- def updateParent(label: Double, instanceWeight: Double): Unit = {
- impurityAggregator.update(parentStats, 0, label, instanceWeight)
+ def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit = {
+ impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
}
/**
@@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator(
featureOffset: Int,
binIndex: Int,
label: Double,
- instanceWeight: Double): Unit = {
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
- label, instanceWeight)
+ label, numSamples, sampleWeight)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 53189e0797b6..8f8a17171f98 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.util.Try
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.TreeEnsembleParams
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD
/**
* Learning and dataset metadata for DecisionTree.
*
+ * @param weightedNumExamples Weighted count of samples in the tree.
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
* For regression: fixed at 0 (no meaning).
* @param maxBins Maximum number of bins, for all features.
* @param featureArity Map: categorical feature index to arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
* @param numBins Number of bins for each feature.
+ * @param minWeightFractionPerNode The minimum fraction of the total sample weight that each
+ * child node must have after a split for the split to be considered valid.
*/
private[spark] class DecisionTreeMetadata(
val numFeatures: Int,
val numExamples: Long,
+ val weightedNumExamples: Double,
val numClasses: Int,
val maxBins: Int,
val featureArity: Map[Int, Int],
@@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata(
val quantileStrategy: QuantileStrategy,
val maxDepth: Int,
val minInstancesPerNode: Int,
+ val minWeightFractionPerNode: Double,
val minInfoGain: Double,
val numTrees: Int,
val numFeaturesPerNode: Int) extends Serializable {
@@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata(
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+ def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
+
/**
* Number of splits for the given feature.
* For unordered features, there is 1 bin per split.
@@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
* as well as the number of splits and bins for each feature.
*/
def buildMetadata(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String): DecisionTreeMetadata = {
@@ -115,7 +122,11 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
s"but was given an empty features vector")
- val numExamples = input.count()
+ val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
+ seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight),
+ combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2)
+ )
+
val numClasses = strategy.algo match {
case Classification => strategy.numClasses
case Regression => 0
@@ -206,17 +217,18 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
}
- new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
- strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+ new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
+ numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
- strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+ strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain,
+ numTrees, numFeaturesPerNode)
}
/**
* Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
*/
def buildMetadata(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: Strategy): DecisionTreeMetadata = {
buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
}
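
The (numExamples, weightSum) pair above is computed in a single pass with RDD.aggregate; the same idea in isolation, with toy data and names of my own:

import org.apache.spark.sql.SparkSession

object WeightedCountSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("weighted-count").getOrCreate()
    val weights = spark.sparkContext.parallelize(Seq(1.0, 0.5, 2.5))

    // One pass over the data: count records and sum their weights, as buildMetadata now does
    // for Instance.weight.
    val (numExamples, weightSum) = weights.aggregate((0L, 0.0))(
      seqOp = (acc, w) => (acc._1 + 1L, acc._2 + w),
      combOp = (a, b) => (a._1 + b._1, a._2 + b._2))

    println(s"numExamples=$numExamples, weightSum=$weightSum")  // 3, 4.0
    spark.stop()
  }
}
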
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 822abd2d3522..fb4c321a146f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -24,10 +24,12 @@ import scala.util.Random
import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.Instrumentation
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
@@ -90,6 +92,24 @@ private[spark] object RandomForest extends Logging with Serializable {
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
+ seed: Long): Array[DecisionTreeModel] = {
+ val instances = input.map { case LabeledPoint(label, features) =>
+ Instance(label, 1.0, features.asML)
+ }
+ run(instances, strategy, numTrees, featureSubsetStrategy, seed, None)
+ }
+
+ /**
+ * Train a random forest.
+ *
+ * @param input Training data: RDD of `Instance`
+ * @return an unweighted set of trees
+ */
+ def run(
+ input: RDD[Instance],
+ strategy: OldStrategy,
+ numTrees: Int,
+ featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation],
prune: Boolean = true, // exposed for testing only, real trees are always pruned
@@ -101,9 +121,10 @@ private[spark] object RandomForest extends Logging with Serializable {
timer.start("init")
- val retaggedInput = input.retag(classOf[LabeledPoint])
+ val retaggedInput = input.retag(classOf[Instance])
val metadata =
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+
instr match {
case Some(instrumentation) =>
instrumentation.logNumFeatures(metadata.numFeatures)
@@ -132,7 +153,8 @@ private[spark] object RandomForest extends Logging with Serializable {
val withReplacement = numTrees > 1
val baggedInput = BaggedPoint
- .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement,
+ (tp: TreePoint) => tp.weight, seed = seed)
.persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
@@ -254,19 +276,21 @@ private[spark] object RandomForest extends Logging with Serializable {
* For unordered features, bins correspond to subsets of categories; either the left or right bin
* for each subset is updated.
*
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (feature, bin).
- * @param treePoint Data point being aggregated.
- * @param splits possible splits indexed (numFeatures)(numSplits)
- * @param unorderedFeatures Set of indices of unordered features.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
- instanceWeight: Double,
+ numSamples: Int,
+ sampleWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
// Use subsampled features
@@ -293,14 +317,15 @@ private[spark] object RandomForest extends Logging with Serializable {
var splitIndex = 0
while (splitIndex < numSplits) {
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
- agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
+ sampleWeight)
}
splitIndex += 1
}
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
}
featureIndexIdx += 1
}
@@ -314,12 +339,14 @@ private[spark] object RandomForest extends Logging with Serializable {
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- instanceWeight: Double,
+ numSamples: Int,
+ sampleWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
@@ -329,7 +356,7 @@ private[spark] object RandomForest extends Logging with Serializable {
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.length) {
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
- agg.update(featureIndexIdx, binIndex, label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
featureIndexIdx += 1
}
} else {
@@ -338,7 +365,7 @@ private[spark] object RandomForest extends Logging with Serializable {
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndex, binIndex, label, instanceWeight)
+ agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
featureIndex += 1
}
}
@@ -427,14 +454,16 @@ private[spark] object RandomForest extends Logging with Serializable {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ val numSamples = baggedPoint.subsampleCounts(treeIndex)
+ val sampleWeight = baggedPoint.sampleWeight
if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
+ featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
- metadata.unorderedFeatures, instanceWeight, featuresForNode)
+ metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
}
- agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
}
}
@@ -594,8 +623,8 @@ private[spark] object RandomForest extends Logging with Serializable {
if (!isLeaf) {
node.split = Some(split)
val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
- val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
- val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
@@ -659,15 +688,20 @@ private[spark] object RandomForest extends Logging with Serializable {
stats.impurity
}
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
val totalCount = leftCount + rightCount
- // If left child or right child doesn't satisfy minimum instances per node,
- // then this split is invalid, return invalid information gain stats.
- if ((leftCount < metadata.minInstancesPerNode) ||
- (rightCount < metadata.minInstancesPerNode)) {
+ val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or minimum
+ // instances per node, then this split is invalid, return invalid information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
@@ -734,7 +768,8 @@ private[spark] object RandomForest extends Logging with Serializable {
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
@@ -876,14 +911,14 @@ private[spark] object RandomForest extends Logging with Serializable {
* and for multiclass classification with a high-arity feature,
* there is one bin per category.
*
- * @param input Training data: RDD of [[LabeledPoint]]
+ * @param input Training data: RDD of [[Instance]]
* @param metadata Learning and dataset metadata
* @param seed random seed
* @return Splits, an Array of [[Split]]
* of size (numFeatures, numSplits)
*/
protected[tree] def findSplits(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
metadata: DecisionTreeMetadata,
seed: Long): Array[Array[Split]] = {
@@ -898,14 +933,14 @@ private[spark] object RandomForest extends Logging with Serializable {
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
} else {
- input.sparkContext.emptyRDD[LabeledPoint]
+ input.sparkContext.emptyRDD[Instance]
}
findSplitsBySorting(sampledInput, metadata, continuousFeatures)
}
private def findSplitsBySorting(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
metadata: DecisionTreeMetadata,
continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
@@ -917,7 +952,8 @@ private[spark] object RandomForest extends Logging with Serializable {
input
.flatMap { point =>
- continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0)
+ continuousFeatures.map(idx => (idx, (point.weight, point.features(idx))))
+ .filter(_._2._2 != 0.0)
}.groupByKey(numPartitions)
.map { case (idx, samples) =>
val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
@@ -982,7 +1018,7 @@ private[spark] object RandomForest extends Logging with Serializable {
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
*
- * @param featureSamples feature values of each sample
+ * @param featureSamples feature values and sample weights of each sample
* @param metadata decision tree metadata
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
@@ -990,7 +1026,7 @@ private[spark] object RandomForest extends Logging with Serializable {
* @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
- featureSamples: Iterable[Double],
+ featureSamples: Iterable[(Double, Double)],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
@@ -1002,19 +1038,24 @@ private[spark] object RandomForest extends Logging with Serializable {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value except zero value
- val partNumSamples = featureSamples.size
- val partValueCountMap = scala.collection.mutable.Map[Double, Int]()
- featureSamples.foreach { x =>
- partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1
+ val partValueCountMap = mutable.Map[Double, Double]()
+ var partNumSamples = 0.0
+ var unweightedNumSamples = 0.0
+ featureSamples.foreach { case (sampleWeight, feature) =>
+ partValueCountMap(feature) = partValueCountMap.getOrElse(feature, 0.0) + sampleWeight
+ partNumSamples += sampleWeight
+ unweightedNumSamples += 1.0
}
// Calculate the expected number of samples for finding splits
- val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
+ val weightedNumSamples = samplesFractionForFindSplits(metadata) *
+ metadata.weightedNumExamples
// add expected zero value count and get complete statistics
- val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) {
- partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples))
+ val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples
+ val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) {
+ partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples))
} else {
- partValueCountMap.toMap
+ partValueCountMap
}
// sort distinct values
@@ -1031,7 +1072,7 @@ private[spark] object RandomForest extends Logging with Serializable {
.toArray
} else {
// stride between splits
- val stride: Double = numSamples.toDouble / (numSplits + 1)
+ val stride: Double = weightedNumSamples / (numSplits + 1)
logDebug("stride = " + stride)
// iterate `valueCount` to find splits
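
A simplified sketch of the weighted split search above: total the weight per distinct value, then space split candidates every stride units of weight rather than every stride samples. The walk below is cruder than the real implementation (which chooses more carefully between adjacent distinct values), and all data and names are illustrative:

object WeightedSplitSketch {
  def main(args: Array[String]): Unit = {
    // (weight, value) pairs, matching the (point.weight, point.features(idx)) tuples above.
    val featureSamples = Seq((0.3, 1.0), (0.3, 2.0), (0.7, 1.0), (0.7, 0.5), (1.0, 3.0))
    val numSplits = 2

    // Total weight per distinct value, sorted by value.
    val valueWeights = featureSamples
      .groupBy(_._2)
      .map { case (value, pairs) => (value, pairs.map(_._1).sum) }
      .toArray.sortBy(_._1)

    val totalWeight = valueWeights.map(_._2).sum
    val stride = totalWeight / (numSplits + 1)   // weight between consecutive split candidates

    // Emit a threshold whenever another `stride` of weight has accumulated.
    val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
    var cumWeight = 0.0
    var target = stride
    valueWeights.foreach { case (value, weight) =>
      cumWeight += weight
      if (cumWeight >= target && splits.size < numSplits) {
        splits += value
        target += stride
      }
    }
    println(splits.mkString(", "))
  }
}
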
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index a6ac64a0463c..72440b2c57aa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.tree.impl
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.rdd.RDD
@@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD
* @param label Label from LabeledPoint
* @param binnedFeatures Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices.
+ * @param weight Sample weight for this TreePoint.
*/
-private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
- extends Serializable {
-}
+private[spark] class TreePoint(
+ val label: Double,
+ val binnedFeatures: Array[Int],
+ val weight: Double) extends Serializable
private[spark] object TreePoint {
@@ -52,7 +54,7 @@ private[spark] object TreePoint {
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
splits: Array[Array[Split]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
// Construct arrays for featureArity for efficiency in the inner loop.
@@ -82,18 +84,18 @@ private[spark] object TreePoint {
* for categorical features.
*/
private def labeledPointToTreePoint(
- labeledPoint: LabeledPoint,
+ instance: Instance,
thresholds: Array[Array[Double]],
featureArity: Array[Int]): TreePoint = {
- val numFeatures = labeledPoint.features.size
+ val numFeatures = instance.features.size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) =
- findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
+ findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
featureIndex += 1
}
- new TreePoint(labeledPoint.label, arr)
+ new TreePoint(instance.label, arr, instance.weight)
}
/**
@@ -106,10 +108,10 @@ private[spark] object TreePoint {
*/
private def findBin(
featureIndex: Int,
- labeledPoint: LabeledPoint,
+ instance: Instance,
featureArity: Int,
thresholds: Array[Double]): Int = {
- val featureValue = labeledPoint.features(featureIndex)
+ val featureValue = instance.features(featureIndex)
if (featureArity == 0) {
val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
@@ -125,7 +127,7 @@ private[spark] object TreePoint {
s"DecisionTree given invalid data:" +
s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
+ s" Bad data point: $instance")
}
featureValue.toInt
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 4aa4c3617e7f..51d5d5c58c57 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite {
*
* @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
* @param impurityStats Stats array. Impurity type is stored in metadata.
+ * @param rawCount The unweighted number of samples falling in this node.
* @param gain Gain, or arbitrary value if leaf node.
* @param leftChild Left child index, or arbitrary value if leaf node.
* @param rightChild Right child index, or arbitrary value if leaf node.
@@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite {
prediction: Double,
impurity: Double,
impurityStats: Array[Double],
+ rawCount: Long,
gain: Double,
leftChild: Int,
rightChild: Int,
@@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite {
val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
- n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split))
+ n.impurityStats.rawCount, n.gain, leftNodeData.head.id, rightNodeData.head.id,
+ SplitData(n.split))
(thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
case _: LeafNode =>
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
- -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
+ node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
id)
}
}
@@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite {
// traversal, this guarantees that child nodes will be built before parent nodes.
val finalNodes = new Array[Node](nodes.length)
nodes.reverseIterator.foreach { case n: NodeData =>
- val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
+ val impurityStats =
+ ImpurityCalculator.getCalculator(impurityType, n.impurityStats, n.rawCount)
val node = if (n.leftChild != -1) {
val leftChild = finalNodes(n.leftChild)
val rightChild = finalNodes(n.rightChild)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index c06c68d44ae1..df01dc007882 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
private[ml] trait DecisionTreeParams extends PredictorParams
- with HasCheckpointInterval with HasSeed {
+ with HasCheckpointInterval with HasSeed with HasWeightCol {
/**
* Maximum depth of the tree (>= 0).
@@ -74,6 +74,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
" Should be >= 1.", ParamValidators.gtEq(1))
+ /**
+ * Minimum fraction of the weighted sample count that each child must have after a split.
+ * If a split causes the fraction of the total weight in the left or right child to be less than
+ * minWeightFractionPerNode, the split will be discarded as invalid.
+ * Should be in the interval [0.0, 0.5).
+ * (default = 0.0)
+ * @group param
+ */
+ final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
+ "minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each child " +
+ "must have after split. If a split causes the fraction of the total weight in the left or " +
+ "right child to be less than minWeightFractionPerNode, the split will be discarded as " +
+ "invalid. Should be in interval [0.0, 0.5)",
+ ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false))
+
/**
* Minimum information gain for a split to be considered at a tree node.
* Should be >= 0.0.
@@ -107,8 +122,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
" trees.")
- setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
- maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
+ minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
+ cacheNodeIds -> false, checkpointInterval -> 10)
/** @group getParam */
final def getMaxDepth: Int = $(maxDepth)
@@ -119,6 +135,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group getParam */
final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
+ /** @group getParam */
+ final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
+
/** @group getParam */
final def getMinInfoGain: Double = $(minInfoGain)
@@ -143,6 +162,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain
strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index a8c5286f3dc1..94224be80752 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -23,6 +23,7 @@ import scala.util.Try
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams}
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -91,8 +92,8 @@ private class RandomForest (
* @return RandomForestModel that can be used for prediction.
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
- val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
- featureSubsetStrategy, seed.toLong, None)
+ val trees: Array[NewDTModel] =
+ NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong)
new RandomForestModel(strategy.algo, trees.map(_.toOld))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 58e8f5be7b9f..d9dcb8001340 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
- @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
+ @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
+ @Since("3.0.0") @BeanProperty var minWeightFractionPerNode: Double = 0.0) extends Serializable {
/**
*/
@@ -96,6 +97,31 @@ class Strategy @Since("1.3.0") (
isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
}
+ // scalastyle:off argcount
+ /**
+ * Backwards compatible constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
+ */
+ @Since("1.0.0")
+ def this(
+ algo: Algo,
+ impurity: Impurity,
+ maxDepth: Int,
+ numClasses: Int,
+ maxBins: Int,
+ quantileCalculationStrategy: QuantileStrategy,
+ categoricalFeaturesInfo: Map[Int, Int],
+ minInstancesPerNode: Int,
+ minInfoGain: Double,
+ maxMemoryInMB: Int,
+ subsamplingRate: Double,
+ useNodeIdCache: Boolean,
+ checkpointInterval: Int) {
+ this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy,
+ categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB,
+ subsamplingRate, useNodeIdCache, checkpointInterval, 0.0)
+ }
+ // scalastyle:on argcount
+
/**
* Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
*/
@@ -108,7 +134,8 @@ class Strategy @Since("1.3.0") (
maxBins: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ minWeightFractionPerNode = 0.0)
}
/**
@@ -171,8 +198,9 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0")
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
- quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
- maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
+ quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode,
+ minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
+ checkpointInterval, minWeightFractionPerNode)
}
}
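
A short sketch of how the extended Strategy would be configured from caller code; this assumes the setter generated by @BeanProperty and the constructor defaults that Strategy already provides:

import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

object StrategySketch {
  def main(args: Array[String]): Unit = {
    // The new parameter defaults to 0.0, so existing callers are unaffected.
    val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 5, numClasses = 2,
      maxBins = 32)
    strategy.setMinWeightFractionPerNode(0.05)   // bean setter generated by @BeanProperty
    println(strategy.minWeightFractionPerNode)
  }
}
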
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index d4448da9eef5..f01a98e74886 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -83,23 +83,29 @@ object Entropy extends Impurity {
* @param numClasses Number of classes for label.
*/
private[spark] class EntropyAggregator(numClasses: Int)
- extends ImpurityAggregator(numClasses) with Serializable {
+ extends ImpurityAggregator(numClasses + 1) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
- if (label >= statsSize) {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ if (label >= numClasses) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
- s" but requires label < numClasses (= $statsSize).")
+ s" but requires label < numClasses (= ${numClasses}).")
}
if (label < 0) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s"but requires label is non-negative.")
}
- allStats(offset + label.toInt) += instanceWeight
+ allStats(offset + label.toInt) += numSamples * sampleWeight
+ allStats(offset + statsSize - 1) += numSamples
}
/**
@@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
- new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new EntropyCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class EntropyCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+ def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
def calculate(): Double = Entropy.calculate(stats, stats.sum)
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats.sum.toLong
+ def count: Double = stats.sum
/**
* Prediction which should be made based on the sufficient statistics.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c5e34ffa4f2e..913ffbbb2457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -80,23 +80,29 @@ object Gini extends Impurity {
* @param numClasses Number of classes for label.
*/
private[spark] class GiniAggregator(numClasses: Int)
- extends ImpurityAggregator(numClasses) with Serializable {
+ extends ImpurityAggregator(numClasses + 1) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
- if (label >= statsSize) {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ if (label >= numClasses) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
- s" but requires label < numClasses (= $statsSize).")
+ s" but requires label < numClasses (= ${numClasses}).")
}
if (label < 0) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
- s"but requires label is non-negative.")
+ s"but requires label to be non-negative.")
}
- allStats(offset + label.toInt) += instanceWeight
+ allStats(offset + label.toInt) += numSamples * sampleWeight
+ allStats(offset + statsSize - 1) += numSamples
}
/**
@@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
- new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: GiniCalculator = new GiniCalculator(stats.clone())
+ def copy: GiniCalculator = new GiniCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
def calculate(): Double = Gini.calculate(stats, stats.sum)
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats.sum.toLong
+ def count: Double = stats.sum
/**
* Prediction which should be made based on the sufficient statistics.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index f151a6a01b65..491473490eba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -81,7 +81,12 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit
/**
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
@@ -122,6 +127,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
stats(i) += other.stats(i)
i += 1
}
+ rawCount += other.rawCount
this
}
@@ -139,13 +145,19 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
stats(i) -= other.stats(i)
i += 1
}
+ rawCount -= other.rawCount
this
}
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long
+ def count: Double
+
+ /**
+ * Raw (unweighted) number of data points accounted for in the sufficient statistics.
+ */
+ var rawCount: Long
/**
* Prediction which should be made based on the sufficient statistics.
@@ -185,11 +197,14 @@ private[spark] object ImpurityCalculator {
* Create an [[ImpurityCalculator]] instance of the given impurity type and with
* the given stats.
*/
- def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
+ def getCalculator(
+ impurity: String,
+ stats: Array[Double],
+ rawCount: Long): ImpurityCalculator = {
impurity.toLowerCase(Locale.ROOT) match {
- case "gini" => new GiniCalculator(stats)
- case "entropy" => new EntropyCalculator(stats)
- case "variance" => new VarianceCalculator(stats)
+ case "gini" => new GiniCalculator(stats, rawCount)
+ case "entropy" => new EntropyCalculator(stats, rawCount)
+ case "variance" => new VarianceCalculator(stats, rawCount)
case _ =>
throw new IllegalArgumentException(
s"ImpurityCalculator builder did not recognize impurity type: $impurity")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index c9bf0db4de3c..a07b919271f7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -66,21 +66,32 @@ object Variance extends Impurity {
/**
* Class for updating views of a vector of sufficient statistics,
- * in order to compute impurity from a sample.
+ * in order to compute impurity from a sample. For variance, we track:
+ * - sum(w_i)
+ * - sum(w_i * y_i)
+ * - sum(w_i * y_i * y_i)
+ * - count(y_i)
* Note: Instances of this class do not hold the data; they operate on views of the data.
*/
private[spark] class VarianceAggregator()
- extends ImpurityAggregator(statsSize = 3) with Serializable {
+ extends ImpurityAggregator(statsSize = 4) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
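+ // instanceWeight is the total weight contributed by the numSamples bagged copies of this point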
+ val instanceWeight = numSamples * sampleWeight
allStats(offset) += instanceWeight
allStats(offset + 1) += instanceWeight * label
allStats(offset + 2) += instanceWeight * label * label
+ allStats(offset + 3) += numSamples
}
/**
@@ -89,7 +100,8 @@ private[spark] class VarianceAggregator()
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
- new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new VarianceCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -99,7 +111,8 @@ private[spark] class VarianceAggregator()
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
require(stats.length == 3,
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
@@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
+ def copy: VarianceCalculator = new VarianceCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats(0).toLong
+ def count: Double = stats(0)
/**
* Prediction which should be made based on the sufficient statistics.
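Aside (editorial sketch, not part of the patch; the object and value names below are invented for illustration): the weighted sufficient statistics accumulated by the aggregators above are enough to recover the impurity, while the raw count is carried separately so that unweighted checks such as minInstancesPerNode still see plain sample counts.

object WeightedVarianceSketch {
  // weighted variance from the three weighted moments, analogous to Variance.calculate
  def variance(weightSum: Double, weightedSum: Double, weightedSquareSum: Double): Double = {
    if (weightSum == 0.0) 0.0
    else weightedSquareSum / weightSum - math.pow(weightedSum / weightSum, 2)
  }

  def main(args: Array[String]): Unit = {
    val labels = Array(1.0, 2.0, 4.0)
    val weights = Array(0.5, 1.0, 2.5)
    // the same statistics a VarianceAggregator bin would hold for these points
    val weightSum = weights.sum
    val weightedSum = labels.zip(weights).map { case (y, w) => w * y }.sum
    val weightedSquareSum = labels.zip(weights).map { case (y, w) => w * y * y }.sum
    val rawCount = labels.length // the unweighted count kept in the extra slot
    println(s"weighted variance = ${variance(weightSum, weightedSum, weightedSquareSum)}")
    println(s"raw count = $rawCount")
  }
}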
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 2930f4900d50..433c78237024 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -42,6 +42,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
+ private val seed = 42
+
override def beforeAll() {
super.beforeAll()
categoricalDataPointsRDD =
@@ -250,7 +252,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
MLTestingUtils.checkCopyAndUids(dt, newTree)
- testTransformer[(Vector, Double)](newData, newTree,
+ testTransformer[(Vector, Double, Double)](newData, newTree,
"prediction", "rawPrediction", "probability") {
case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
assert(pred === rawPred.argmax,
@@ -327,6 +329,49 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
dt.fit(df)
}
+ test("training with sample weights") {
+ val df = {
+ val nPoints = 100
+ val coefficients = Array(
+ -0.57997, 0.912083, -0.371077,
+ -0.16624, -0.84355, -0.048509)
+
+ val xMean = Array(5.843, 3.057)
+ val xVariance = Array(0.6856, 0.1899)
+
+ val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+ coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
+
+ sc.parallelize(testData, 4).toDF()
+ }
+ val numClasses = 3
+ val predEquals = (x: Double, y: Double) => x == y
+ // (impurity, maxDepth)
+ val testParams = Seq(
+ ("gini", 10),
+ ("entropy", 10),
+ ("gini", 5)
+ )
+ for ((impurity, maxDepth) <- testParams) {
+ val estimator = new DecisionTreeClassifier()
+ .setMaxDepth(maxDepth)
+ .setSeed(seed)
+ .setMinWeightFractionPerNode(0.049)
+ .setImpurity(impurity)
+
+ MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7))
+ MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
@@ -350,7 +395,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
allParamSettings, checkModelData)
-
// Continuous splits with tree depth 2
val continuousData: DataFrame =
TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ba4a9cf08278..027583ffc60b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
@@ -141,7 +141,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
MLTestingUtils.checkCopyAndUids(rf, model)
- testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction",
+ testTransformer[(Vector, Double, Double)](df, model, "prediction", "rawPrediction",
"probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
assert(pred === rawPred.argmax,
s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
@@ -180,7 +180,6 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
-
test("Feature importance with toy data") {
val numClasses = 2
val rf = new RandomForestClassifier()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 29a438396516..38bfa6626e3c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
@@ -26,6 +26,7 @@ import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -35,11 +36,17 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+ private var linearRegressionData: DataFrame = _
+
+ private val seed = 42
override def beforeAll() {
super.beforeAll()
categoricalDataPointsRDD =
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
+ linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+ xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF()
}
/////////////////////////////////////////////////////////////////////////////
@@ -88,7 +95,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
val model = dt.fit(df)
- testTransformer[(Vector, Double)](df, model, "features", "variance") {
+ testTransformer[(Vector, Double, Double)](df, model, "features", "variance") {
case Row(features: Vector, variance: Double) =>
val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate()
assert(variance === expectedVariance,
@@ -101,7 +108,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
.setMaxBins(6)
.setSeed(0)
- testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF),
+ testTransformerByGlobalCheckFunc[(Vector, Double, Double)](varianceDF, dt.fit(varianceDF),
"variance") { case rows: Seq[Row] =>
val calculatedVariances = rows.map(_.getDouble(0))
@@ -159,6 +166,28 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("training with sample weights") {
+ val df = linearRegressionData
+ val numClasses = 0
+ val testParams = Seq(5, 10)
+ for (maxDepth <- testParams) {
+ val estimator = new DecisionTreeRegressor()
+ .setMaxDepth(maxDepth)
+ .setMinWeightFractionPerNode(0.05)
+ .setSeed(123)
+ MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99))
+ MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses,
+ MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 1.0), seed)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 90ceb7dee38f..76532897dff2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -891,6 +891,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
.setStandardization(standardization)
.setRegParam(regParam)
.setElasticNetParam(elasticNetParam)
+ .setSolver(solver)
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
index 77ab3d8bb75f..63985482795b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
*/
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
- test("BaggedPoint RDD: without subsampling") {
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ test("BaggedPoint RDD: without subsampling with weights") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { lp =>
+ Instance(lp.label, 0.5, lp.features.asML)
+ }
val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false,
+ (instance: Instance) => instance.weight * 4.0, seed = 42)
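+ // the weight function returns 0.5 * 4.0 = 2.0 for every point; with subsamplingRate = 1.0,
+ // a single subsample, and no replacement, each point is kept exactly once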
baggedRDD.collect().foreach { baggedPoint =>
- assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+ assert(baggedPoint.subsampleCounts.size === 1 && baggedPoint.subsampleCounts(0) === 1)
+ assert(baggedPoint.sampleWeight === 2.0)
}
}
@@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 1.0)
val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true,
+ (_: LabeledPoint) => 2.0, seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
+ // when actual subsampling is performed, the weight function is ignored for now,
+ // so sampleWeight stays 1.0
+ assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
}
}
@@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD =
+ BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
}
@@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 0)
val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false,
+ (_: LabeledPoint) => 2.0, seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
+ // when actual subsampling is performed, the weight function is ignored for now,
+ // so sampleWeight stays 1.0
+ assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
}
}
@@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false,
+ seed = seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 5caa5117d575..b89cc6053b06 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -19,10 +19,11 @@ package org.apache.spark.ml.tree.impl
import scala.annotation.tailrec
import scala.collection.mutable
+import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._
@@ -46,7 +47,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
test("Binary classification with continuous features: split calculation") {
- val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML)
+ val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance)
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
@@ -58,7 +59,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("Binary classification with binary (ordered) categorical features: split calculation") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
@@ -75,7 +76,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Binary classification with 3-ary (ordered) categorical features," +
" with no samples for one category: split calculation") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
@@ -93,12 +94,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("find splits for a continuous feature") {
// find splits for normal case
{
- val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 200000, 200000.0, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0)
+ val featureSamples = Array.fill(10000)((1.0, math.random)).filter(_._2 != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 5)
assert(fakeMetadata.numSplits(0) === 5)
@@ -109,15 +110,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// SPARK-16957: Use midpoints for split values.
{
- val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
// possibleSplits <= numSplits
{
- val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0)
+ val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1)
+ .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2)
assert(splits === expectedSplits)
@@ -125,7 +127,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// possibleSplits > numSplits
{
- val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0)
+ val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3)
+ .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0)
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits)
@@ -135,12 +138,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits in metadata
{
- val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
+ val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
assert(splits === expectedSplits)
@@ -150,13 +153,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits when most samples close to the minimum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
- .map(_.toDouble)
+ val featureSamples =
+ Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
assert(splits === expectedSplits)
@@ -164,37 +167,55 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits when most samples close to the maximum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0,
Map(), Set(),
Array(2), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
- .map(_.toDouble).filter(_ != 0.0)
+ val featureSamples =
+ Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
val expectedSplits = Array((1.0 + 2.0) / 2)
assert(splits === expectedSplits)
}
- // find splits for constant feature
+ // find splits for arbitrarily scaled data
{
- val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
+ Map(), Set(),
+ Array(6), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0.0, 0, 0
+ )
+ val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random))
+ val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x) }
+ val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x) }
+ val splitsUnitWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesUnitWeight, fakeMetadata, 0)
+ val splitsSmallWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesSmallWeight, fakeMetadata, 0)
+ val splitsLargeWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesLargeWeight, fakeMetadata, 0)
+ assert(splitsUnitWeight === splitsSmallWeight)
+ assert(splitsUnitWeight === splitsLargeWeight)
+ }
+
+ // find splits when most weight is close to the minimum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0)
- val featureSamplesEmpty = Array.empty[Double]
+ val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map {
+ case (w, x) => (w.toDouble, x.toDouble)
+ }
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits === Array.empty[Double])
- val splitsEmpty =
- RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
- assert(splitsEmpty === Array.empty[Double])
+ assert(splits === Array(1.5, 2.5, 3.5, 4.5, 5.5))
}
}
test("train with empty arrays") {
- val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
+ val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])).toInstance
val data = Array.fill(5)(lp)
val rdd = sc.parallelize(data)
@@ -209,8 +230,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("train with constant features") {
- val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
- val data = Array.fill(5)(lp)
+ val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)).toInstance
+ val data = Array.fill(5)(instance)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(
OldAlgo.Classification,
@@ -222,7 +243,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0)
- assert(tree.rootNode.prediction === lp.label)
+ assert(tree.rootNode.prediction === instance.label)
// Test with no categorical features
val strategy2 = new OldStrategy(
@@ -233,11 +254,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
assert(tree2.rootNode.impurity === -1.0)
assert(tree2.depth === 0)
- assert(tree2.rootNode.prediction === lp.label)
+ assert(tree2.rootNode.prediction === instance.label)
}
test("Multiclass classification with unordered categorical features: split calculations") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance)
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(
@@ -278,7 +299,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("Multiclass classification with ordered categorical features: split calculations") {
- val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ .map(_.asML.toInstance)
assert(arr.length === 3000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
@@ -310,7 +332,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -352,7 +374,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -404,7 +426,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance))
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
@@ -424,7 +446,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("Second level node building with vs. without groups") {
- val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML)
+ val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance)
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
// For tree with 1 group
@@ -468,7 +490,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
val numFeatures = 50
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
- val rdd = sc.parallelize(arr).map(_.asML)
+ val rdd = sc.parallelize(arr).map(_.asML.toInstance)
// Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy(
@@ -581,16 +603,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
left2 parent
left right
*/
- val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L)
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
- val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+ val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L)
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
val parentImp = parent.impurityStats
- val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+ val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L)
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
@@ -647,12 +669,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// feature_0 = 0 improves the impurity measure, despite the prediction will always be 0
// in both branches.
val arr = Array(
- LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
- LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+ Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
+ Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
+ Instance(1.0, 1.0, Vectors.dense(1.0, 0.0)),
+ Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
+ Instance(1.0, 1.0, Vectors.dense(1.0, 1.0))
)
val rdd = sc.parallelize(arr)
@@ -677,13 +699,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5
// in both branches.
val arr = Array(
- LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
- LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
- LabeledPoint(0.0, Vectors.dense(1.0, 0.0)),
- LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
- LabeledPoint(0.0, Vectors.dense(1.0, 1.0)),
- LabeledPoint(0.5, Vectors.dense(1.0, 1.0))
+ Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)),
+ Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)),
+ Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)),
+ Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)),
+ Instance(0.0, 1.0, Vectors.dense(1.0, 1.0)),
+ Instance(0.5, 1.0, Vectors.dense(1.0, 1.0))
)
val rdd = sc.parallelize(arr)
@@ -700,6 +722,56 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(unprunedTree.numNodes === 5)
assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size)
}
+
+ test("weights at arbitrary scale") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
+ val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance))
+ val rddWithSmallWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 0.001, inst.features)
+ }
+ val rddWithBigWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 1000, inst.features)
+ }
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
+ val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None)
+
+ val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None)
+ unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) =>
+ TreeTests.checkEqual(unitTree, smallWeightTree)
+ }
+
+ val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None)
+ unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) =>
+ TreeTests.checkEqual(unitTree, bigWeightTree)
+ }
+ }
+
+ test("minWeightFraction and minInstancesPerNode") {
+ val data = Array(
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(1.0, 0.1, Vectors.dense(1.0))
+ )
+ val rdd = sc.parallelize(data)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
+ minWeightFractionPerNode = 0.5)
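+ // total weight is 4.1; the only possible split isolates the 0.1-weight point, leaving one
+ // child far below half of the node weight, so the root is not split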
+ val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree1.depth === 0)
+
+ strategy.minWeightFractionPerNode = 0.0
+ val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree2.depth === 1)
+
+ strategy.minInstancesPerNode = 2
+ val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree3.depth === 0)
+
+ strategy.minInstancesPerNode = 1
+ val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree4.depth === 1)
+ }
}
private object RandomForestSuite {
@@ -717,7 +789,7 @@ private object RandomForestSuite {
else {
nodes.head match {
case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc)
- case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count)
+ case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount)
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala
index f41abe48f2c5..ce473ebf52e9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala
@@ -27,7 +27,7 @@ class TreePointSuite extends SparkFunSuite {
val ser = new KryoSerializer(conf).newInstance()
- val point = new TreePoint(1.0, Array(1, 2, 3))
+ val point = new TreePoint(1.0, Array(1, 2, 3), 1.0)
val point2 = ser.deserialize[TreePoint](ser.serialize(point))
assert(point.label === point2.label)
assert(point.binnedFeatures === point2.binnedFeatures)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index ae9794b87b08..f3096e28d3d4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -18,13 +18,15 @@
package org.apache.spark.ml.tree.impl
import scala.collection.JavaConverters._
+import scala.util.Random
import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/**
* Convert the given data to a DataFrame, and set the features and label metadata.
+ *
* @param data Dataset. Categorical features and labels must already have 0-based indices.
* This must be non-empty.
* @param categoricalFeatures Map: categorical feature index to number of distinct values
@@ -39,16 +42,22 @@ private[ml] object TreeTests extends SparkFunSuite {
* @return DataFrame with metadata
*/
def setMetadata(
- data: RDD[LabeledPoint],
+ data: RDD[_],
categoricalFeatures: Map[Int, Int],
numClasses: Int): DataFrame = {
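+ // accept either Instance or LabeledPoint input; LabeledPoints are converted with weight 1.0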
+ val dataOfInstance: RDD[Instance] = data.map {
+ case instance: Instance => instance
+ case labeledPoint: LabeledPoint => labeledPoint.toInstance
+ }
val spark = SparkSession.builder()
.sparkContext(data.sparkContext)
.getOrCreate()
import spark.implicits._
- val df = data.toDF()
- val numFeatures = data.first().features.size
+ val df = dataOfInstance.toDF()
+ val numFeatures = dataOfInstance.first().features.size
val featuresAttributes = Range(0, numFeatures).map { feature =>
if (categoricalFeatures.contains(feature)) {
NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
@@ -64,7 +73,7 @@ private[ml] object TreeTests extends SparkFunSuite {
}
val labelMetadata = labelAttribute.toMetadata()
df.select(df("features").as("features", featuresMetadata),
- df("label").as("label", labelMetadata))
+ df("label").as("label", labelMetadata), df("weight"))
}
/**
@@ -80,6 +89,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/**
* Set label metadata (particularly the number of classes) on a DataFrame.
+ *
* @param data Dataset. Categorical features and labels must already have 0-based indices.
* This must be non-empty.
* @param numClasses Number of classes label can take. If 0, mark as continuous.
@@ -124,8 +134,8 @@ private[ml] object TreeTests extends SparkFunSuite {
* make mistakes such as creating loops of Nodes.
*/
private def checkEqual(a: Node, b: Node): Unit = {
- assert(a.prediction === b.prediction)
- assert(a.impurity === b.impurity)
+ assert(a.prediction ~== b.prediction absTol 1e-8)
+ assert(a.impurity ~== b.impurity absTol 1e-8)
(a, b) match {
case (aye: InternalNode, bee: InternalNode) =>
assert(aye.split === bee.split)
@@ -156,6 +166,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/**
* Helper method for constructing a tree for testing.
* Given left, right children, construct a parent node.
+ *
* @param split Split for parent node
* @return Parent node with children attached
*/
@@ -163,8 +174,8 @@ private[ml] object TreeTests extends SparkFunSuite {
val leftImp = left.impurityStats
val rightImp = right.impurityStats
val parentImp = leftImp.copy.add(rightImp)
- val leftWeight = leftImp.count / parentImp.count.toDouble
- val rightWeight = rightImp.count / parentImp.count.toDouble
+ val leftWeight = leftImp.count / parentImp.count
+ val rightWeight = rightImp.count / parentImp.count
val gain = parentImp.calculate() -
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
val pred = parentImp.predict
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index 91a8b14625a8..2c73700ee962 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -205,8 +205,8 @@ object MLTestingUtils extends SparkFunSuite {
seed: Long): Unit = {
val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(
data, seed)
- val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
+ val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
modelEquals(weightedModel, overSampledModel)
}
@@ -228,7 +228,8 @@ object MLTestingUtils extends SparkFunSuite {
List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
}
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
- val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS)
+ val outlierModel = estimator.set(estimator.weightCol, "weight")
+ .fit(outlierDS)
modelEquals(trueModel, outlierModel)
}
@@ -241,7 +242,7 @@ object MLTestingUtils extends SparkFunSuite {
estimator: E with HasWeightCol,
modelEquals: (M, M) => Unit): Unit = {
estimator.set(estimator.weightCol, "weight")
- val models = Seq(0.001, 1.0, 1000.0).map { w =>
+ val models = Seq(0.01, 1.0, 1000.0).map { w =>
val df = data.withColumn("weight", lit(w))
estimator.fit(df)
}
@@ -268,4 +269,20 @@ object MLTestingUtils extends SparkFunSuite {
assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false)))
(newDataset, newDatasetD, newDatasetF)
}
+
+ def modelPredictionEquals[M <: PredictionModel[_, M]](
+ data: DataFrame,
+ compareFunc: (Double, Double) => Boolean,
+ fractionInTol: Double)(
+ model1: M,
+ model2: M): Unit = {
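+ // require that at least fractionInTol of the row-wise prediction pairs satisfy compareFunc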
+ val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
+ val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
+ val inTol = pred1.zip(pred2).count { case (p1, p2) =>
+ val x = p1.getDouble(0)
+ val y = p2.getDouble(0)
+ compareFunc(x, y)
+ }
+ assert(inTol / pred1.length.toDouble >= fractionInTol)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 34bc303ac607..8378a599362a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(strategy.isMulticlassClassification)
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
@@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 3, maxBins = maxBins,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
@@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy)
@@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 3, maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
index d0f02dd966bd..078c6e6fff9f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -18,23 +18,63 @@
package org.apache.spark.mllib.tree
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.tree.impurity._
/**
* Test suites for `GiniAggregator` and `EntropyAggregator`.
*/
class ImpuritySuite extends SparkFunSuite {
+
+ private val seed = 42
+
test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] {
- gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
}
}
test("Entropy does not support negative labels") {
val entropy = new EntropyAggregator(2)
intercept[IllegalArgumentException] {
- entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
+ }
+ }
+
+ test("Classification impurities are insensitive to scaling") {
+ val rng = new scala.util.Random(seed)
+ val weightedCounts = Array.fill(5)(rng.nextDouble())
+ val smallWeightedCounts = weightedCounts.map(_ * 0.0001)
+ val largeWeightedCounts = weightedCounts.map(_ * 10000)
+ Seq(Gini, Entropy).foreach { impurity =>
+ val impurity1 = impurity.calculate(weightedCounts, weightedCounts.sum)
+ assert(impurity.calculate(smallWeightedCounts, smallWeightedCounts.sum)
+ ~== impurity1 relTol 0.005)
+ assert(impurity.calculate(largeWeightedCounts, largeWeightedCounts.sum)
+ ~== impurity1 relTol 0.005)
}
}
+
+ test("Regression impurities are insensitive to scaling") {
+ def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = {
+ samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) =>
+ (wn + w, wy + w * y, wyy + w * y * y)
+ }
+ }
+ val rng = new scala.util.Random(seed)
+ val samples = Array.fill(10)(rng.nextDouble())
+ val _weights = Array.fill(10)(rng.nextDouble())
+ val smallWeights = _weights.map(_ * 0.0001)
+ val largeWeights = _weights.map(_ * 10000)
+ val (count, sum, sumSquared) = computeStats(samples, _weights)
+ Seq(Variance).foreach { impurity =>
+ val impurity1 = impurity.calculate(count, sum, sumSquared)
+ val (smallCount, smallSum, smallSumSquared) = computeStats(samples, smallWeights)
+ val (largeCount, largeSum, largeSumSquared) = computeStats(samples, largeWeights)
+ assert(impurity.calculate(smallCount, smallSum, smallSumSquared) ~== impurity1 relTol 0.005)
+ assert(impurity.calculate(largeCount, largeSum, largeSumSquared) ~== impurity1 relTol 0.005)
+ }
+ }
+
}
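Aside (editorial sketch, not part of the patch; names invented for illustration): the scale invariance asserted by the new ImpuritySuite tests follows because the classification impurities depend only on class proportions, so multiplying every weighted count by the same constant cancels out (Gini is shown below; Entropy and Variance behave the same way).

object ImpurityScalingSketch {
  // Gini impurity from weighted class counts: 1 - sum_i (c_i / total)^2
  def gini(weightedCounts: Array[Double]): Double = {
    val total = weightedCounts.sum
    1.0 - weightedCounts.map { c => val p = c / total; p * p }.sum
  }

  def main(args: Array[String]): Unit = {
    val counts = Array(3.0, 5.0, 2.0)
    val scaled = counts.map(_ * 10000)
    // the two printed values agree up to floating-point error
    println(gini(counts))
    println(gini(scaled))
  }
}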