diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2327917e2cad..94158bf5d6e3 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -31,7 +31,7 @@ object TestingUtils {
* Note that if x or y is extremely close to zero, i.e., smaller than Double.MinPositiveValue,
* the relative tolerance is meaningless, so the exception will be raised to warn users.
*/
- private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ private[ml] def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
val absX = math.abs(x)
val absY = math.abs(y)
val diff = math.abs(x - y)
@@ -48,7 +48,7 @@ object TestingUtils {
/**
* Private helper function for comparing two values using absolute tolerance.
*/
- private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
math.abs(x - y) < eps
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 9f60f0896ec5..6a9b3564d63f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -22,18 +22,21 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
-
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
/**
* Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning)
@@ -65,6 +68,9 @@ class DecisionTreeClassifier @Since("1.4.0") (
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
+ @Since("2.2.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -96,6 +102,16 @@ class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.6.0")
override def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -106,14 +122,23 @@ class DecisionTreeClassifier @Since("1.4.0") (
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
-
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+ require(numClasses > 0, s"DecisionTreeClassifier (in extractLabeledPoints) found numClasses =" +
+ s" $numClasses, but requires numClasses > 0.")
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" +
+ s" dataset with invalid label $label. Labels must be integers in range" +
+ s" [0, $numClasses).")
+ Instance(label, weight, features)
+ }
val strategy = getOldStrategy(categoricalFeatures, numClasses)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(params: _*)
- val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
@@ -124,11 +149,12 @@ class DecisionTreeClassifier @Since("1.4.0") (
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
- val instr = Instrumentation.create(this, data)
- instr.logParams(params: _*)
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 0L, instr = Some(instr), parentUID = Some(uid))
+ val instances = data.map(_.toInstance(1.0))
+ val instr = Instrumentation.create(this, instances)
+ instr.logParams(params: _*)
+ val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+ featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
instr.logSuccess(m)
@@ -176,6 +202,7 @@ class DecisionTreeClassificationModel private[ml] (
/**
* Construct a decision tree classification model.
+ *
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index ce834f1d17e0..5674e4813492 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -21,19 +21,20 @@ import org.json4s.{DefaultFormats, JObject}
import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
/**
* Random Forest learning algorithm for
@@ -126,20 +127,20 @@ class RandomForestClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
+ val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance(1.0))
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
val trees = RandomForest
- .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
- val numFeatures = oldDataset.first().features.size
+ val numFeatures = instances.first().features.size
val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
instr.logSuccess(m)
m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
index cce3ca45ccd8..7e6e4c5a26e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
@@ -26,4 +26,4 @@ import org.apache.spark.ml.linalg.Vector
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
-private[ml] case class Instance(label: Double, weight: Double, features: Vector)
+private[spark] case class Instance(label: Double, weight: Double, features: Vector)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
index c5d0ec1a8d35..a19f6a88968b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala
@@ -35,4 +35,9 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features:
override def toString: String = {
s"($label,$features)"
}
+
+ private[spark] def toInstance(weight: Double): Instance = {
+ Instance(label, weight, features)
+ }
+
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 01c5cc1c7efa..8205714dff46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -23,8 +23,9 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._
@@ -33,8 +34,10 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
/**
@@ -64,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** @group setParam */
+ @Since("2.2.0")
+ def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value)
+
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
@@ -99,16 +105,31 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
@Since("2.0.0")
def setVarianceCol(value: String): this.type = set(varianceCol, value)
+ /**
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
+ * Default is not set, so all instances have weight one.
+ *
+ * @group setParam
+ */
+ @Since("2.2.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+ val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+ val instances =
+ dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
val strategy = getOldStrategy(categoricalFeatures)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(params: _*)
- val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
+ val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
@@ -122,8 +143,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)
- val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = $(seed), instr = Some(instr), parentUID = Some(uid))
+ val instances = data.map(_.toInstance(1.0))
+ val trees = RandomForest.run(instances, oldStrategy, numTrees = 1,
+ featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
instr.logSuccess(m)
@@ -153,6 +175,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
*
* Decision tree (Wikipedia) model for regression.
* It supports both continuous and categorical features.
+ *
* @param rootNode Root of the decision tree
*/
@Since("1.4.0")
@@ -171,6 +194,7 @@ class DecisionTreeRegressionModel private[ml] (
/**
* Construct a decision tree regression model.
+ *
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: Node, numFeatures: Int) =
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2f524a8c5784..0b6bcb9c8123 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -22,7 +22,6 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{PredictionModel, Predictor}
-import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree._
@@ -31,10 +30,8 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions._
-
+import org.apache.spark.sql.functions.{col, udf}
/**
* Random Forest
@@ -117,20 +114,20 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+
+ val instances = extractLabeledPoints(dataset).map(_.toInstance(1.0))
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, instances)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, numTrees,
featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval)
val trees = RandomForest
- .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
+ .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
-
- val numFeatures = oldDataset.first().features.size
+ val numFeatures = instances.first().features.size
val m = new RandomForestRegressionModel(trees, numFeatures)
instr.logSuccess(m)
m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
index 4e372702f0c6..2bb7020232f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala
@@ -33,13 +33,20 @@ import org.apache.spark.util.random.XORShiftRandom
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
*
* @param datum Data instance
- * @param subsampleWeights Weight of this instance in each subsampled dataset.
- *
- * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted
- * dataset support, update. (We store subsampleWeights as Double for this future extension.)
+ * @param subsampleCounts Number of samples of this instance in each subsampled dataset.
+ * @param sampleWeight The weight of this instance.
*/
-private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double])
- extends Serializable
+private[spark] class BaggedPoint[Datum](
+ val datum: Datum,
+ val subsampleCounts: Array[Int],
+ val sampleWeight: Double) extends Serializable {
+
+ /**
+ * Subsample counts weighted by the sample weight.
+ */
+ def weightedCounts: Array[Double] = subsampleCounts.map(_ * sampleWeight)
+
+}
private[spark] object BaggedPoint {
@@ -52,6 +59,7 @@ private[spark] object BaggedPoint {
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param numSubsamples Number of subsamples of this RDD to take.
* @param withReplacement Sampling with/without replacement.
+ * @param extractSampleWeight A function to get the sample weight of each datum.
* @param seed Random seed.
* @return BaggedPoint dataset representation.
*/
@@ -60,12 +68,14 @@ private[spark] object BaggedPoint {
subsamplingRate: Double,
numSubsamples: Int,
withReplacement: Boolean,
+ extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0,
seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = {
+ // TODO: implement weighted bootstrapping
if (withReplacement) {
convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
} else {
if (numSubsamples == 1 && subsamplingRate == 1.0) {
- convertToBaggedRDDWithoutSampling(input)
+ convertToBaggedRDDWithoutSampling(input, extractSampleWeight)
} else {
convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
}
@@ -82,16 +92,16 @@ private[spark] object BaggedPoint {
val rng = new XORShiftRandom
rng.setSeed(seed + partitionIndex + 1)
instances.map { instance =>
- val subsampleWeights = new Array[Double](numSubsamples)
+ val subsampleCounts = new Array[Int](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
val x = rng.nextDouble()
- subsampleWeights(subsampleIndex) = {
- if (x < subsamplingRate) 1.0 else 0.0
+ subsampleCounts(subsampleIndex) = {
+ if (x < subsamplingRate) 1 else 0
}
subsampleIndex += 1
}
- new BaggedPoint(instance, subsampleWeights)
+ new BaggedPoint(instance, subsampleCounts, 1.0)
}
}
}
@@ -106,20 +116,20 @@ private[spark] object BaggedPoint {
val poisson = new PoissonDistribution(subsample)
poisson.reseedRandomGenerator(seed + partitionIndex + 1)
instances.map { instance =>
- val subsampleWeights = new Array[Double](numSubsamples)
+ val subsampleCounts = new Array[Int](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
- subsampleWeights(subsampleIndex) = poisson.sample()
+ subsampleCounts(subsampleIndex) = poisson.sample()
subsampleIndex += 1
}
- new BaggedPoint(instance, subsampleWeights)
+ new BaggedPoint(instance, subsampleCounts, 1.0)
}
}
}
private def convertToBaggedRDDWithoutSampling[Datum] (
- input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
- input.map(datum => new BaggedPoint(datum, Array(1.0)))
+ input: RDD[Datum],
+ extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = {
+ input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum)))
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
index 61091bb803e4..3124f4ee3c10 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala
@@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator(
/**
* Update the stats for a given (feature, bin) for ordered features, using the given label.
*/
- def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
+ def update(
+ featureIndex: Int,
+ binIndex: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
val i = featureOffsets(featureIndex) + binIndex * statsSize
- impurityAggregator.update(allStats, i, label, instanceWeight)
+ impurityAggregator.update(allStats, i, label, numSamples, sampleWeight)
}
/**
* Update the parent node stats using the given label.
*/
- def updateParent(label: Double, instanceWeight: Double): Unit = {
- impurityAggregator.update(parentStats, 0, label, instanceWeight)
+ def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit = {
+ impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight)
}
/**
@@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator(
featureOffset: Int,
binIndex: Int,
label: Double,
- instanceWeight: Double): Unit = {
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
impurityAggregator.update(allStats, featureOffset + binIndex * statsSize,
- label, instanceWeight)
+ label, numSamples, sampleWeight)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 8a9dcb486b7b..b67cfff75293 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.util.Try
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.RandomForestParams
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD
/**
* Learning and dataset metadata for DecisionTree.
*
+ * @param weightedNumExamples Weighted count of samples in the tree.
* @param numClasses For classification: labels can take values {0, ..., numClasses - 1}.
* For regression: fixed at 0 (no meaning).
* @param maxBins Maximum number of bins, for all features.
* @param featureArity Map: categorical feature index to arity.
* I.e., the feature takes values in {0, ..., arity - 1}.
* @param numBins Number of bins for each feature.
+ * @param minWeightFractionPerNode The minimum fraction of the total sample weight that must be
+ * present in a leaf node in order to be considered a valid split.
*/
private[spark] class DecisionTreeMetadata(
val numFeatures: Int,
val numExamples: Long,
+ val weightedNumExamples: Double,
val numClasses: Int,
val maxBins: Int,
val featureArity: Map[Int, Int],
@@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata(
val quantileStrategy: QuantileStrategy,
val maxDepth: Int,
val minInstancesPerNode: Int,
+ val minWeightFractionPerNode: Double,
val minInfoGain: Double,
val numTrees: Int,
val numFeaturesPerNode: Int) extends Serializable {
@@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata(
def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex)
+ def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples
+
/**
* Number of splits for the given feature.
* For unordered features, there is 1 bin per split.
@@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
* as well as the number of splits and bins for each feature.
*/
def buildMetadata(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String): DecisionTreeMetadata = {
@@ -115,7 +122,10 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
s"but was given an empty features vector")
- val numExamples = input.count()
+ val (numExamples, weightSum) = input.aggregate((0L, 0.0))(
+ (acc, x) => (acc._1 + 1L, acc._2 + x.weight),
+ (acc1, acc2) => (acc1._1 + acc2._1, acc1._2 + acc2._2))
+
val numClasses = strategy.algo match {
case Classification => strategy.numClasses
case Regression => 0
@@ -206,17 +216,18 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
}
- new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
- strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
+ new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses,
+ numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
- strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
+ strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain,
+ numTrees, numFeaturesPerNode)
}
/**
* Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree.
*/
def buildMetadata(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: Strategy): DecisionTreeMetadata = {
buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 008dd19c2498..80b4a97a82bd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -24,7 +24,8 @@ import scala.util.Random
import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.impl.Utils
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.Instrumentation
@@ -82,11 +83,11 @@ private[spark] object RandomForest extends Logging {
/**
* Train a random forest.
*
- * @param input Training data: RDD of `LabeledPoint`
+ * @param input Training data: RDD of [[org.apache.spark.ml.feature.Instance]]
* @return an unweighted set of trees
*/
def run(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
strategy: OldStrategy,
numTrees: Int,
featureSubsetStrategy: String,
@@ -100,9 +101,10 @@ private[spark] object RandomForest extends Logging {
timer.start("init")
- val retaggedInput = input.retag(classOf[LabeledPoint])
+ val retaggedInput = input.retag(classOf[Instance])
val metadata =
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
+
instr match {
case Some(instrumentation) =>
instrumentation.logNumFeatures(metadata.numFeatures)
@@ -129,7 +131,8 @@ private[spark] object RandomForest extends Logging {
val withReplacement = numTrees > 1
val baggedInput = BaggedPoint
- .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed)
+ .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement,
+ (tp: TreePoint) => tp.weight, seed = seed)
.persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
@@ -250,19 +253,21 @@ private[spark] object RandomForest extends Logging {
* For unordered features, bins correspond to subsets of categories; either the left or right bin
* for each subset is updated.
*
- * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
- * each (feature, bin).
- * @param treePoint Data point being aggregated.
- * @param splits possible splits indexed (numFeatures)(numSplits)
- * @param unorderedFeatures Set of indices of unordered features.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
+ * each (feature, bin).
+ * @param treePoint Data point being aggregated.
+ * @param splits Possible splits indexed (numFeatures)(numSplits)
+ * @param unorderedFeatures Set of indices of unordered features.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
- instanceWeight: Double,
+ numSamples: Int,
+ sampleWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
// Use subsampled features
@@ -289,14 +294,15 @@ private[spark] object RandomForest extends Logging {
var splitIndex = 0
while (splitIndex < numSplits) {
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
- agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
+ agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples,
+ sampleWeight)
}
splitIndex += 1
}
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight)
}
featureIndexIdx += 1
}
@@ -310,12 +316,14 @@ private[spark] object RandomForest extends Logging {
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param instanceWeight Weight (importance) of instance in dataset.
+ * @param numSamples Number of times this instance occurs in the sample.
+ * @param sampleWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- instanceWeight: Double,
+ numSamples: Int,
+ sampleWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label
@@ -325,7 +333,7 @@ private[spark] object RandomForest extends Logging {
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.length) {
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
- agg.update(featureIndexIdx, binIndex, label, instanceWeight)
+ agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight)
featureIndexIdx += 1
}
} else {
@@ -334,7 +342,7 @@ private[spark] object RandomForest extends Logging {
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
- agg.update(featureIndex, binIndex, label, instanceWeight)
+ agg.update(featureIndex, binIndex, label, numSamples, sampleWeight)
featureIndex += 1
}
}
@@ -423,14 +431,16 @@ private[spark] object RandomForest extends Logging {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ val numSamples = baggedPoint.subsampleCounts(treeIndex)
+ val sampleWeight = baggedPoint.sampleWeight
if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
+ featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
- metadata.unorderedFeatures, instanceWeight, featuresForNode)
+ metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
}
- agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
+ agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
}
}
@@ -590,8 +600,8 @@ private[spark] object RandomForest extends Logging {
if (!isLeaf) {
node.split = Some(split)
val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
- val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
- val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+ val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
+ val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
@@ -655,15 +665,20 @@ private[spark] object RandomForest extends Logging {
stats.impurity
}
+ val leftRawCount = leftImpurityCalculator.rawCount
+ val rightRawCount = rightImpurityCalculator.rawCount
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
val totalCount = leftCount + rightCount
- // If left child or right child doesn't satisfy minimum instances per node,
- // then this split is invalid, return invalid information gain stats.
- if ((leftCount < metadata.minInstancesPerNode) ||
- (rightCount < metadata.minInstancesPerNode)) {
+ val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) ||
+ (rightRawCount < metadata.minInstancesPerNode)
+ val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) ||
+ (rightCount < metadata.minWeightPerNode)
+ // If left child or right child doesn't satisfy minimum weight per node or minimum
+ // instances per node, then this split is invalid, return invalid information gain stats.
+ if (violatesMinInstancesPerNode || violatesMinWeightPerNode) {
return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
@@ -730,7 +745,8 @@ private[spark] object RandomForest extends Logging {
// Find best split.
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { case splitIdx =>
- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
+ val leftChildStats =
+ binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
@@ -872,14 +888,14 @@ private[spark] object RandomForest extends Logging {
* and for multiclass classification with a high-arity feature,
* there is one bin per category.
*
- * @param input Training data: RDD of [[LabeledPoint]]
+ * @param input Training data: RDD of [[Instance]]
* @param metadata Learning and dataset metadata
* @param seed random seed
* @return Splits, an Array of [[Split]]
* of size (numFeatures, numSplits)
*/
protected[tree] def findSplits(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
metadata: DecisionTreeMetadata,
seed: Long): Array[Array[Split]] = {
@@ -900,14 +916,14 @@ private[spark] object RandomForest extends Logging {
logDebug("fraction of data used for calculating quantiles = " + fraction)
input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
} else {
- input.sparkContext.emptyRDD[LabeledPoint]
+ input.sparkContext.emptyRDD[Instance]
}
findSplitsBySorting(sampledInput, metadata, continuousFeatures)
}
private def findSplitsBySorting(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
metadata: DecisionTreeMetadata,
continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
@@ -918,7 +934,7 @@ private[spark] object RandomForest extends Logging {
val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
input
- .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+ .flatMap(point => continuousFeatures.map(idx => (idx, (point.weight, point.features(idx)))))
.groupByKey(numPartitions)
.map { case (idx, samples) =>
val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
@@ -982,7 +998,7 @@ private[spark] object RandomForest extends Logging {
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
*
- * @param featureSamples feature values of each sample
+ * @param featureSamples feature values and sample weights of each sample
* @param metadata decision tree metadata
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
@@ -990,7 +1006,7 @@ private[spark] object RandomForest extends Logging {
* @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
- featureSamples: Iterable[Double],
+ featureSamples: Iterable[(Double, Double)],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
@@ -1002,9 +1018,9 @@ private[spark] object RandomForest extends Logging {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
- val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
- case ((m, cnt), x) =>
- (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
+ val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Double], 0.0)) {
+ case ((m, cnt), (w, x)) =>
+ (m + ((x, m.getOrElse(x, 0.0) + w)), cnt + w)
}
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
index a6ac64a0463c..16b00299594a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.tree.impl
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.rdd.RDD
@@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD
* @param label Label from LabeledPoint
* @param binnedFeatures Binned feature values.
* Same length as LabeledPoint.features, but values are bin indices.
+ * @param weight Sample weight for this TreePoint.
*/
-private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
- extends Serializable {
-}
+private[spark] class TreePoint(
+ val label: Double,
+ val binnedFeatures: Array[Int],
+ val weight: Double) extends Serializable
private[spark] object TreePoint {
@@ -52,7 +54,7 @@ private[spark] object TreePoint {
* @return TreePoint dataset representation
*/
def convertToTreeRDD(
- input: RDD[LabeledPoint],
+ input: RDD[Instance],
splits: Array[Array[Split]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
// Construct arrays for featureArity for efficiency in the inner loop.
@@ -82,18 +84,18 @@ private[spark] object TreePoint {
* for categorical features.
*/
private def labeledPointToTreePoint(
- labeledPoint: LabeledPoint,
+ instance: Instance,
thresholds: Array[Array[Double]],
featureArity: Array[Int]): TreePoint = {
- val numFeatures = labeledPoint.features.size
+ val numFeatures = instance.features.size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) =
- findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex))
+ findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex))
featureIndex += 1
}
- new TreePoint(labeledPoint.label, arr)
+ new TreePoint(instance.label, arr, instance.weight)
}
/**
@@ -106,10 +108,10 @@ private[spark] object TreePoint {
*/
private def findBin(
featureIndex: Int,
- labeledPoint: LabeledPoint,
+ instance: Instance,
featureArity: Int,
thresholds: Array[Double]): Int = {
- val featureValue = labeledPoint.features(featureIndex)
+ val featureValue = instance.features(featureIndex)
if (featureArity == 0) {
val idx = java.util.Arrays.binarySearch(thresholds, featureValue)
@@ -125,7 +127,7 @@ private[spark] object TreePoint {
s"DecisionTree given invalid data:" +
s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," +
s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
+ " Bad data point: " + instance.toString)
}
featureValue.toInt
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 0d6e9034e5ce..820c42fa43b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -219,7 +219,7 @@ private[ml] object TreeEnsembleModel {
importances.changeValue(feature, scaledGain, _ + scaledGain)
computeFeatureImportance(n.leftChild, importances)
computeFeatureImportance(n.rightChild, importances)
- case n: LeafNode =>
+ case _: LeafNode =>
// do nothing
}
}
@@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite {
*
* @param id Index used for tree reconstruction. Indices follow a pre-order traversal.
* @param impurityStats Stats array. Impurity type is stored in metadata.
+ * @param rawCount The unweighted number of samples falling in this node.
* @param gain Gain, or arbitrary value if leaf node.
* @param leftChild Left child index, or arbitrary value if leaf node.
* @param rightChild Right child index, or arbitrary value if leaf node.
@@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite {
prediction: Double,
impurity: Double,
impurityStats: Array[Double],
+ rawCount: Long,
gain: Double,
leftChild: Int,
rightChild: Int,
@@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite {
val (leftNodeData, leftIdx) = build(n.leftChild, id + 1)
val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1)
val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats,
- n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split))
+ n.impurityStats.rawCount, n.gain, leftNodeData.head.id, rightNodeData.head.id,
+ SplitData(n.split))
(thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx)
case _: LeafNode =>
(Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats,
- -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
+ node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))),
id)
}
}
@@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite {
// traversal, this guarantees that child nodes will be built before parent nodes.
val finalNodes = new Array[Node](nodes.length)
nodes.reverseIterator.foreach { case n: NodeData =>
- val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats)
+ val impurityStats =
+ ImpurityCalculator.getCalculator(impurityType, n.impurityStats, n.rawCount)
val node = if (n.leftChild != -1) {
val leftChild = finalNodes(n.leftChild)
val rightChild = finalNodes(n.rightChild)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 5eb707dfe7bc..ff5955d40980 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
private[ml] trait DecisionTreeParams extends PredictorParams
- with HasCheckpointInterval with HasSeed {
+ with HasCheckpointInterval with HasSeed with HasWeightCol {
/**
* Maximum depth of the tree (>= 0).
@@ -71,6 +71,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
" Should be >= 1.", ParamValidators.gtEq(1))
+ /**
+ * Minimum fraction of the weighted sample count that each child must have after split.
+ * If a split causes the fraction of the total weight in the left or right child to be less than
+ * minWeightFractionPerNode, the split will be discarded as invalid.
+ * Should be in the interval [0.0, 0.5).
+ * (default = 0.0)
+ * @group param
+ */
+ final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this,
+ "minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each child " +
+ "must have after split. If a split causes the fraction of the total weight in the left or " +
+ "right child to be less than minWeightFractionPerNode, the split will be discarded as " +
+ "invalid. Should be in interval [0.0, 0.5)",
+ ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false))
+
/**
* Minimum information gain for a split to be considered at a tree node.
* Should be >= 0.0.
@@ -104,8 +119,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
" trees.")
- setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
- maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1,
+ minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256,
+ cacheNodeIds -> false, checkpointInterval -> 10)
/**
* @deprecated This method is deprecated and will be removed in 2.2.0.
@@ -137,6 +153,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams
/** @group getParam */
final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
+ /** @group getParam */
+ final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode)
+
/**
* @deprecated This method is deprecated and will be removed in 2.2.0.
* @group setParam
@@ -196,6 +215,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams
strategy.maxMemoryInMB = getMaxMemoryInMB
strategy.minInfoGain = getMinInfoGain
strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.minWeightFractionPerNode = getMinWeightFractionPerNode
strategy.useNodeIdCache = getCacheNodeIds
strategy.numClasses = numClasses
strategy.categoricalFeaturesInfo = categoricalFeatures
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index d1331a57de27..9e1401588142 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -23,6 +23,7 @@ import scala.util.Try
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.{Instance => NewInstance}
import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, RandomForestParams => NewRFParams}
import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -91,8 +92,11 @@ private class RandomForest (
* @return RandomForestModel that can be used for prediction.
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
- val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
- featureSubsetStrategy, seed.toLong, None)
+ val instances = input.map { case LabeledPoint(label, features) =>
+ NewInstance(label, 1.0, features.asML)
+ }
+ val trees: Array[NewDTModel] =
+ NewRandomForest.run(instances, strategy, numTrees, featureSubsetStrategy, seed.toLong, None)
new RandomForestModel(strategy.algo, trees.map(_.toOld))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 58e8f5be7b9f..5806741b413b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
- @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
+ @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10,
+ private[spark] var minWeightFractionPerNode: Double = 0.0) extends Serializable {
/**
*/
@@ -108,7 +109,8 @@ class Strategy @Since("1.3.0") (
maxBins: Int,
categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) {
this(algo, impurity, maxDepth, numClasses, maxBins, Sort,
- categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ minWeightFractionPerNode = 0.0)
}
/**
@@ -171,8 +173,9 @@ class Strategy @Since("1.3.0") (
@Since("1.2.0")
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
- quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
- maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
+ quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode,
+ minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache,
+ checkpointInterval, minWeightFractionPerNode)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index d4448da9eef5..f01a98e74886 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -83,23 +83,29 @@ object Entropy extends Impurity {
* @param numClasses Number of classes for label.
*/
private[spark] class EntropyAggregator(numClasses: Int)
- extends ImpurityAggregator(numClasses) with Serializable {
+ extends ImpurityAggregator(numClasses + 1) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
- if (label >= statsSize) {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ if (label >= numClasses) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
- s" but requires label < numClasses (= $statsSize).")
+ s" but requires label < numClasses (= ${numClasses}).")
}
if (label < 0) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s"but requires label is non-negative.")
}
- allStats(offset + label.toInt) += instanceWeight
+ allStats(offset + label.toInt) += numSamples * sampleWeight
+ allStats(offset + statsSize - 1) += numSamples
}
/**
@@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
- new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new EntropyCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class EntropyCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: EntropyCalculator = new EntropyCalculator(stats.clone())
+ def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal
def calculate(): Double = Entropy.calculate(stats, stats.sum)
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats.sum.toLong
+ def count: Double = stats.sum
/**
* Prediction which should be made based on the sufficient statistics.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c5e34ffa4f2e..913ffbbb2457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -80,23 +80,29 @@ object Gini extends Impurity {
* @param numClasses Number of classes for label.
*/
private[spark] class GiniAggregator(numClasses: Int)
- extends ImpurityAggregator(numClasses) with Serializable {
+ extends ImpurityAggregator(numClasses + 1) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
- if (label >= statsSize) {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ if (label >= numClasses) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
- s" but requires label < numClasses (= $statsSize).")
+ s" but requires label < numClasses (= ${numClasses}).")
}
if (label < 0) {
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
- s"but requires label is non-negative.")
+ s"but requires label to be non-negative.")
}
- allStats(offset + label.toInt) += instanceWeight
+ allStats(offset + label.toInt) += numSamples * sampleWeight
+ allStats(offset + statsSize - 1) += numSamples
}
/**
@@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int)
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
- new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: GiniCalculator = new GiniCalculator(stats.clone())
+ def copy: GiniCalculator = new GiniCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul
def calculate(): Double = Gini.calculate(stats, stats.sum)
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats.sum.toLong
+ def count: Double = stats.sum
/**
* Prediction which should be made based on the sufficient statistics.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index a5bdc2c6d2c9..6a814e658caa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -79,7 +79,12 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit
/**
* Get an [[ImpurityCalculator]] for a (node, feature, bin).
@@ -120,6 +125,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
stats(i) += other.stats(i)
i += 1
}
+ rawCount += other.rawCount
this
}
@@ -137,13 +143,19 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
stats(i) -= other.stats(i)
i += 1
}
+ rawCount -= other.rawCount
this
}
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long
+ def count: Double
+
+ /**
+ * Raw number of data points accounted for in the sufficient statistics.
+ */
+ var rawCount: Long
/**
* Prediction which should be made based on the sufficient statistics.
@@ -183,11 +195,14 @@ private[spark] object ImpurityCalculator {
* Create an [[ImpurityCalculator]] instance of the given impurity type and with
* the given stats.
*/
- def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = {
+ def getCalculator(
+ impurity: String,
+ stats: Array[Double],
+ rawCount: Long): ImpurityCalculator = {
impurity match {
- case "gini" => new GiniCalculator(stats)
- case "entropy" => new EntropyCalculator(stats)
- case "variance" => new VarianceCalculator(stats)
+ case "gini" => new GiniCalculator(stats, rawCount)
+ case "entropy" => new EntropyCalculator(stats, rawCount)
+ case "variance" => new VarianceCalculator(stats, rawCount)
case _ =>
throw new IllegalArgumentException(
s"ImpurityCalculator builder did not recognize impurity type: $impurity")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index c9bf0db4de3c..a07b919271f7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -66,21 +66,32 @@ object Variance extends Impurity {
/**
* Class for updating views of a vector of sufficient statistics,
- * in order to compute impurity from a sample.
+ * in order to compute impurity from a sample. For variance, we track:
+ * - sum(w_i)
+ * - sum(w_i * y_i)
+ * - sum(w_i * y_i * y_i)
+ * - count(y_i)
* Note: Instances of this class do not hold the data; they operate on views of the data.
*/
private[spark] class VarianceAggregator()
- extends ImpurityAggregator(statsSize = 3) with Serializable {
+ extends ImpurityAggregator(statsSize = 4) with Serializable {
/**
* Update stats for one (node, feature, bin) with the given label.
* @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
* @param offset Start index of stats for this (node, feature, bin).
*/
- def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
+ def update(
+ allStats: Array[Double],
+ offset: Int,
+ label: Double,
+ numSamples: Int,
+ sampleWeight: Double): Unit = {
+ val instanceWeight = numSamples * sampleWeight
allStats(offset) += instanceWeight
allStats(offset + 1) += instanceWeight * label
allStats(offset + 2) += instanceWeight * label * label
+ allStats(offset + 3) += numSamples
}
/**
@@ -89,7 +100,8 @@ private[spark] class VarianceAggregator()
* @param offset Start index of stats for this (node, feature, bin).
*/
def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
- new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
+ new VarianceCalculator(allStats.view(offset, offset + statsSize - 1).toArray,
+ allStats(offset + statsSize - 1).toLong)
}
}
@@ -99,7 +111,8 @@ private[spark] class VarianceAggregator()
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: Long)
+ extends ImpurityCalculator(stats) {
require(stats.length == 3,
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
@@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
/**
* Make a deep copy of this [[ImpurityCalculator]].
*/
- def copy: VarianceCalculator = new VarianceCalculator(stats.clone())
+ def copy: VarianceCalculator = new VarianceCalculator(stats.clone(), rawCount)
/**
* Calculate the impurity from the stored sufficient statistics.
@@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa
def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2))
/**
- * Number of data points accounted for in the sufficient statistics.
+ * Weighted number of data points accounted for in the sufficient statistics.
*/
- def count: Long = stats(0).toLong
+ def count: Double = stats(0)
/**
* Prediction which should be made based on the sufficient statistics.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9dc6..a48050c4a25f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,8 +21,9 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
+import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode}
import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
@@ -42,6 +43,9 @@ class DecisionTreeClassifierSuite
private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
+ private var linearMulticlassDataset: DataFrame = _
+
+ private val seed = 42
override def beforeAll() {
super.beforeAll()
@@ -58,6 +62,20 @@ class DecisionTreeClassifierSuite
categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
.map(_.asML)
+ linearMulticlassDataset = {
+ val nPoints = 100
+ val coefficients = Array(
+ -0.57997, 0.912083, -0.371077,
+ -0.16624, -0.84355, -0.048509)
+
+ val xMean = Array(5.843, 3.057)
+ val xVariance = Array(0.6856, 0.1899)
+
+ val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+ coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
+
+ sc.parallelize(testData, 4).toDF()
+ }
}
test("params") {
@@ -246,7 +264,8 @@ class DecisionTreeClassifierSuite
val categoricalFeatures = Map(0 -> 3)
val numClasses = 3
- val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val newData: DataFrame =
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
// copied model must have the same parent.
@@ -273,7 +292,7 @@ class DecisionTreeClassifierSuite
LabeledPoint(1, Vectors.dense(0, 3, 9)),
LabeledPoint(0, Vectors.dense(0, 2, 6))
))
- val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
+ val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), Map(0 -> 1), 2)
val dt = new DecisionTreeClassifier().setMaxDepth(3)
dt.fit(df)
}
@@ -295,7 +314,7 @@ class DecisionTreeClassifierSuite
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0)))
val data = sc.parallelize(arr)
- val df = TreeTests.setMetadata(data, Map(0 -> 3), 2)
+ val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), Map(0 -> 3), 2)
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
val dt = new DecisionTreeClassifier()
@@ -326,7 +345,7 @@ class DecisionTreeClassifierSuite
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val numFeatures = data.first().features.size
val categoricalFeatures = (0 to numFeatures).map(i => (i, 2)).toMap
- val df = TreeTests.setMetadata(data, categoricalFeatures, 2)
+ val df = TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 2)
val model = dt.fit(df)
@@ -351,6 +370,36 @@ class DecisionTreeClassifierSuite
dt.fit(df)
}
+ test("training with sample weights") {
+ val df = linearMulticlassDataset
+ val numClasses = 3
+ val predEquals = (x: Double, y: Double) => x == y
+ // (impurity, maxDepth)
+ val testParams = Seq(
+ ("gini", 10),
+ ("entropy", 10),
+ ("gini", 5)
+ )
+ for ((impurity, maxDepth) <- testParams) {
+ val estimator = new DecisionTreeClassifier()
+ .setMaxDepth(maxDepth)
+ .setSeed(seed)
+ .setMinWeightFractionPerNode(0.049)
+ .setImpurity(impurity)
+
+ MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 0.9))
+ MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[DecisionTreeClassificationModel,
+ DecisionTreeClassifier](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, predEquals, 1.0), seed)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
@@ -371,12 +420,12 @@ class DecisionTreeClassifierSuite
// Categorical splits with tree depth 2
val categoricalData: DataFrame =
- TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map(0 -> 2, 1 -> 3), numClasses = 2)
testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
// Continuous splits with tree depth 2
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], numClasses = 2)
testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
// Continuous splits with tree depth 0
@@ -399,7 +448,8 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML), oldStrategy)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, numClasses)
val newTree = dt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d4b..b6dbcbff3eca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -251,7 +251,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
sc.setCheckpointDir(path)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val df: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, numClasses = 2)
val gbt = new GBTClassifier()
.setMaxDepth(2)
.setLossType("logistic")
@@ -346,7 +347,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val df: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, numClasses)
val importances = gbt.fit(df).featureImportances
val mostImportantFeature = importances.argmax
@@ -373,7 +375,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic")
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], numClasses = 2)
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
}
@@ -394,7 +396,8 @@ private object GBTClassifierSuite extends SparkFunSuite {
gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML))
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, numClasses = 2)
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTClassificationModel.fromOld(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index ee2aefee7a6d..47c393bcdf7b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -134,7 +134,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC](
- dataset.as[LabeledPoint], estimator, 2, modelEquals)
+ dataset.as[LabeledPoint], estimator, 2, modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC](
dataset.as[LabeledPoint], estimator, modelEquals, 42L)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 43547a4aafcb..d5d8fabaafa4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -1850,7 +1850,7 @@ class LogisticRegressionSuite
MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, LogisticRegression](
dataset.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, LogisticRegression](
- dataset.as[LabeledPoint], estimator, numClasses, modelEquals)
+ dataset.as[LabeledPoint], estimator, numClasses, modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, LogisticRegression](
dataset.as[LabeledPoint], estimator, modelEquals, seed)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8dd..cc68562a6975 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -178,7 +178,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes](
dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes](
- dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals)
+ dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals, outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes](
dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee514..f9d332277058 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -18,17 +18,17 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -129,7 +129,7 @@ class RandomForestClassifierSuite
}
test("predictRaw and predictProbability") {
- val rdd = orderedLabeledPoints5_20
+ val rdd = orderedLabeledPoints5_20.map(_.toInstance(1.0))
val rf = new RandomForestClassifier()
.setImpurity("Gini")
.setMaxDepth(3)
@@ -167,7 +167,6 @@ class RandomForestClassifierSuite
/////////////////////////////////////////////////////////////////////////////
// Tests of feature importance
/////////////////////////////////////////////////////////////////////////////
-
test("Feature importance with toy data") {
val numClasses = 2
val rf = new RandomForestClassifier()
@@ -179,7 +178,7 @@ class RandomForestClassifierSuite
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val data: RDD[Instance] = TreeTests.featureImportanceData(sc).map(_.toInstance(1.0))
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
@@ -212,7 +211,7 @@ class RandomForestClassifierSuite
}
val rf = new RandomForestClassifier().setNumTrees(2)
- val rdd = TreeTests.getTreeReadWriteData(sc)
+ val rdd = TreeTests.getTreeReadWriteData(sc).map(_.toInstance(1.0))
val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")
@@ -239,7 +238,8 @@ private object RandomForestClassifierSuite extends SparkFunSuite {
val oldModel = OldRandomForest.trainClassifier(
data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy,
rf.getSeed.toInt)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, numClasses)
val newModel = rf.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestClassificationModel.fromOld(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b527..1238b2fa8edb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -18,15 +18,14 @@
package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
-import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
- DecisionTreeSuite => OldDecisionTreeSuite}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
@@ -34,13 +33,20 @@ class DecisionTreeRegressorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import DecisionTreeRegressorSuite.compareAPIs
+ import testImplicits._
private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+ private var linearRegressionData: DataFrame = _
+
+ private val seed = 42
override def beforeAll() {
super.beforeAll()
categoricalDataPointsRDD =
sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
+ linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput(
+ intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
+ xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF()
}
/////////////////////////////////////////////////////////////////////////////
@@ -68,7 +74,8 @@ class DecisionTreeRegressorSuite
test("copied model must have the same parent") {
val categoricalFeatures = Map(0 -> 2, 1 -> 2)
- val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+ val df = TreeTests.setMetadata(categoricalDataPointsRDD.map(_.toInstance(1.0)),
+ categoricalFeatures, numClasses = 0)
val model = new DecisionTreeRegressor()
.setImpurity("variance")
.setMaxDepth(2)
@@ -85,7 +92,8 @@ class DecisionTreeRegressorSuite
.setVarianceCol("variance")
val categoricalFeatures = Map(0 -> 2, 1 -> 2)
- val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0)
+ val df = TreeTests.setMetadata(categoricalDataPointsRDD.map(_.toInstance(1.0)),
+ categoricalFeatures, numClasses = 0)
val model = dt.fit(df)
val predictions = model.transform(df)
@@ -98,7 +106,7 @@ class DecisionTreeRegressorSuite
s"Expected variance $expectedVariance but got $variance.")
}
- val varianceData: RDD[LabeledPoint] = TreeTests.varianceData(sc)
+ val varianceData: RDD[Instance] = TreeTests.varianceData(sc).map(_.toInstance(1.0))
val varianceDF = TreeTests.setMetadata(varianceData, Map.empty[Int, Int], 0)
dt.setMaxDepth(1)
.setMaxBins(6)
@@ -125,7 +133,7 @@ class DecisionTreeRegressorSuite
.setSeed(123)
// In this data, feature 1 is very important.
- val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+ val data: RDD[Instance] = TreeTests.featureImportanceData(sc).map(_.toInstance(1.0))
val categoricalFeatures = Map.empty[Int, Int]
val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
@@ -146,6 +154,27 @@ class DecisionTreeRegressorSuite
}
}
+ test("training with sample weights") {
+ val df = linearRegressionData
+ val numClasses = 0
+ val testParams = Seq(5, 10)
+ for (maxDepth <- testParams) {
+ val estimator = new DecisionTreeRegressor()
+ .setMaxDepth(maxDepth)
+ .setMinWeightFractionPerNode(0.05)
+ MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.05), 0.9))
+ MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses,
+ MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.1), 0.8),
+ outlierRatio = 2)
+ MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel,
+ DecisionTreeRegressor](df.as[LabeledPoint], estimator,
+ MLTestingUtils.modelPredictionEquals(df, RelativeErrorComparison(_, _, 0.01), 1.0), seed)
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
@@ -159,7 +188,7 @@ class DecisionTreeRegressorSuite
}
val dt = new DecisionTreeRegressor()
- val rdd = TreeTests.getTreeReadWriteData(sc)
+ val rdd = TreeTests.getTreeReadWriteData(sc).map(_.toInstance(1.0))
// Categorical splits with tree depth 2
val categoricalData: DataFrame =
@@ -192,7 +221,8 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
val numFeatures = data.first().features.size
val oldStrategy = dt.getOldStrategy(categoricalFeatures)
val oldTree = OldDecisionTree.train(data.map(OldLabeledPoint.fromML), oldStrategy)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, numClasses = 0)
val newTree = dt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea9b..e805e42649e6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -157,7 +157,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+ val df: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 0)
val importances = gbt.fit(df).featureImportances
val mostImportantFeature = importances.argmax
@@ -183,7 +183,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], numClasses = 0)
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
}
}
@@ -203,7 +203,8 @@ private object GBTRegressorSuite extends SparkFunSuite {
val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
val oldModel = oldGBT.run(data.map(OldLabeledPoint.fromML))
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newData: DataFrame =
+ TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, numClasses = 0)
val newModel = gbt.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = GBTRegressionModel.fromOld(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f6c..df3dbd0af3f3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -839,10 +839,12 @@ class LinearRegressionSuite
.setStandardization(standardization)
.setRegParam(regParam)
.setElasticNetParam(elasticNetParam)
+ .setSolver(solver)
MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals)
MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression](
- datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals)
+ datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals,
+ outlierRatio = 3)
MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression](
datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f84a..01207931f57b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -28,6 +29,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+
/**
* Test suite for [[RandomForestRegressor]].
*/
@@ -86,7 +88,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
// In this data, feature 1 is very important.
val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
val categoricalFeatures = Map.empty[Int, Int]
- val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+ val df: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)), categoricalFeatures, 0)
val model = rf.fit(df)
@@ -123,7 +125,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance")
val continuousData: DataFrame =
- TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
+ TreeTests.setMetadata(rdd.map(_.toInstance(1.0)), Map.empty[Int, Int], numClasses = 0)
testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
}
}
@@ -143,7 +145,8 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy,
rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt)
- val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newData: DataFrame = TreeTests.setMetadata(data.map(_.toInstance(1.0)),
+ categoricalFeatures, numClasses = 0)
val newModel = rf.fit(newData)
// Use parent from newTree since this is not checked anyways.
val oldModelAsNew = RandomForestRegressionModel.fromOld(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
index 77ab3d8bb75f..0b09577171c4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.mllib.tree.EnsembleTestHelper
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
*/
class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
- test("BaggedPoint RDD: without subsampling") {
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ test("BaggedPoint RDD: without subsampling with weights") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { lp =>
+ Instance(lp.label, 0.5, lp.features.asML)
+ }
val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false,
+ (instance: Instance) => instance.weight * 4.0, seed = 42)
baggedRDD.collect().foreach { baggedPoint =>
- assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+ assert(baggedPoint.subsampleCounts.size == 1 && baggedPoint.subsampleCounts(0) == 1)
+ assert(baggedPoint.sampleWeight === 2.0)
}
}
@@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 1.0)
val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true,
+ (_: LabeledPoint) => 2.0, seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
+ // should ignore weight function for now
+ assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
}
}
@@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD =
+ BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
}
@@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val (expectedMean, expectedStddev) = (1.0, 0)
val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false,
+ (_: LabeledPoint) => 2.0, seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
+ // should ignore weight function for now
+ assert(baggedRDD.collect().forall(_.sampleWeight === 1.0))
}
}
@@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
val rdd = sc.parallelize(arr)
seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false,
+ seed = seed)
+ val subsampleCounts: Array[Array[Double]] =
+ baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect()
EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
expectedStddev, epsilon = 0.01)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index e1ab7c2d6520..775ca669d527 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -18,10 +18,11 @@
package org.apache.spark.ml.tree.impl
import scala.collection.mutable
+import scala.language.implicitConversions
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._
@@ -43,7 +44,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
/////////////////////////////////////////////////////////////////////////////
test("Binary classification with continuous features: split calculation") {
- val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML)
+ val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100)
@@ -55,7 +56,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("Binary classification with binary (ordered) categorical features: split calculation") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
@@ -72,7 +73,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("Binary classification with 3-ary (ordered) categorical features," +
" with no samples for one category: split calculation") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2,
@@ -90,12 +91,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
test("find splits for a continuous feature") {
// find splits for normal case
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(6), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array.fill(200000)(math.random)
+ val featureSamples = Array.fill(200000)((1.0, math.random))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 5)
assert(fakeMetadata.numSplits(0) === 5)
@@ -107,12 +108,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits should not return identical splits
// when there are not enough split candidates, reduce the number of splits in metadata
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(5), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+ val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(1.0, 2.0))
// check returned splits are distinct
@@ -121,47 +122,67 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
// find splits when most samples close to the minimum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+ val featureSamples =
+ Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(2.0, 3.0))
}
// find splits when most samples close to the maximum
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(2), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+ val featureSamples =
+ Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble))
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits === Array(1.0))
}
- // find splits for constant feature
+ // find splits for arbitrarily scaled data
{
- val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
+ Map(), Set(),
+ Array(6), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0.0, 0, 0
+ )
+ val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random))
+ val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x)}
+ val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x)}
+ val splitsUnitWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesUnitWeight, fakeMetadata, 0)
+ val splitsSmallWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesSmallWeight, fakeMetadata, 0)
+ val splitsLargeWeight = RandomForest
+ .findSplitsForContinuousFeature(featureSamplesLargeWeight, fakeMetadata, 0)
+ assert(splitsUnitWeight === splitsSmallWeight)
+ assert(splitsUnitWeight === splitsLargeWeight)
+ }
+
+ // find splits when most weight is close to the minimum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0,
Map(), Set(),
Array(3), Gini, QuantileStrategy.Sort,
- 0, 0, 0.0, 0, 0
+ 0, 0, 0.0, 0.0, 0, 0
)
- val featureSamples = Array(0, 0, 0).map(_.toDouble)
- val featureSamplesEmpty = Array.empty[Double]
+ val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map {
+ case (w, x) => (w.toDouble, x.toDouble)
+ }
val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
- assert(splits === Array.empty[Double])
- val splitsEmpty =
- RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0)
- assert(splitsEmpty === Array.empty[Double])
+ assert(splits === Array(1.0, 2.0))
}
}
test("train with empty arrays") {
- val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
+ val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])).toInstance(1.0)
val data = Array.fill(5)(lp)
val rdd = sc.parallelize(data)
@@ -176,8 +197,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("train with constant features") {
- val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
- val data = Array.fill(5)(lp)
+ val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)).toInstance(1.0)
+ val data = Array.fill(5)(instance)
val rdd = sc.parallelize(data)
val strategy = new OldStrategy(
OldAlgo.Classification,
@@ -189,7 +210,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0)
- assert(tree.rootNode.prediction === lp.label)
+ assert(tree.rootNode.prediction === instance.label)
// Test with no categorical features
val strategy2 = new OldStrategy(
@@ -200,11 +221,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
assert(tree2.rootNode.impurity === -1.0)
assert(tree2.depth === 0)
- assert(tree2.rootNode.prediction === lp.label)
+ assert(tree2.rootNode.prediction === instance.label)
}
test("Multiclass classification with unordered categorical features: split calculations") {
- val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(
@@ -245,7 +266,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("Multiclass classification with ordered categorical features: split calculations") {
- val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML)
+ val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+ .map(_.asML.toInstance(1.0))
assert(arr.length === 3000)
val rdd = sc.parallelize(arr)
val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100,
@@ -277,7 +299,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance(1.0)))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -319,7 +341,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance(1.0)))
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5,
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
@@ -371,7 +393,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(0.0, Vectors.dense(2.0)),
LabeledPoint(1.0, Vectors.dense(2.0)))
- val input = sc.parallelize(arr)
+ val input = sc.parallelize(arr.map(_.toInstance(1.0)))
// Must set maxBins s.t. the feature will be treated as an ordered categorical feature.
val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1,
@@ -390,7 +412,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
test("Second level node building with vs. without groups") {
- val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML)
+ val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance(1.0))
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
// For tree with 1 group
@@ -434,7 +456,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) {
val numFeatures = 50
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
- val rdd = sc.parallelize(arr).map(_.asML)
+ val rdd = sc.parallelize(arr).map(_.asML.toInstance(1.0))
// Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy(
@@ -547,16 +569,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
left2 parent
left right
*/
- val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0))
+ val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L)
val left = new LeafNode(0.0, leftImp.calculate(), leftImp)
- val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0))
+ val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L)
val right = new LeafNode(2.0, rightImp.calculate(), rightImp)
val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5))
val parentImp = parent.impurityStats
- val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0))
+ val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L)
val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp)
val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0))
@@ -602,6 +624,57 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}
+
+ test("weights at arbitrary scale") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10)
+ val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance(1.0)))
+ val rddWithSmallWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 0.001, inst.features)
+ }
+ val rddWithBigWeights = rddWithUnitWeights.map { inst =>
+ Instance(inst.label, 1000, inst.features)
+ }
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2)
+ val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None)
+
+ val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None)
+ unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) =>
+ TreeTests.checkEqual(unitTree, smallWeightTree)
+ }
+
+ val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None)
+ unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) =>
+ TreeTests.checkEqual(unitTree, bigWeightTree)
+ }
+ }
+
+ test("minWeightFraction and minInstancesPerNode") {
+ val data = Array(
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(0.0, 1.0, Vectors.dense(0.0)),
+ Instance(1.0, 0.1, Vectors.dense(1.0))
+ )
+ val rdd = sc.parallelize(data)
+ val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2,
+ minWeightFractionPerNode = 0.5)
+ val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree1.depth == 0)
+
+ strategy.minWeightFractionPerNode = 0.0
+ val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree2.depth == 1)
+
+ strategy.minInstancesPerNode = 2
+ val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree3.depth == 0)
+
+ strategy.minInstancesPerNode = 1
+ val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None)
+ assert(tree4.depth == 1)
+ }
+
}
private object RandomForestSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index c90cb8ca1034..999aa80b7750 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -18,13 +18,15 @@
package org.apache.spark.ml.tree.impl
import scala.collection.JavaConverters._
+import scala.util.Random
import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
-import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.tree._
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/**
* Convert the given data to a DataFrame, and set the features and label metadata.
+ *
* @param data Dataset. Categorical features and labels must already have 0-based indices.
* This must be non-empty.
* @param categoricalFeatures Map: categorical feature index -> number of distinct values
@@ -39,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite {
* @return DataFrame with metadata
*/
def setMetadata(
- data: RDD[LabeledPoint],
+ data: RDD[Instance],
categoricalFeatures: Map[Int, Int],
numClasses: Int): DataFrame = {
val spark = SparkSession.builder()
@@ -66,7 +69,7 @@ private[ml] object TreeTests extends SparkFunSuite {
}
val labelMetadata = labelAttribute.toMetadata()
df.select(df("features").as("features", featuresMetadata),
- df("label").as("label", labelMetadata))
+ df("label").as("label", labelMetadata), df("weight"))
}
/** Java-friendly version of [[setMetadata()]] */
@@ -74,12 +77,14 @@ private[ml] object TreeTests extends SparkFunSuite {
data: JavaRDD[LabeledPoint],
categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer],
numClasses: Int): DataFrame = {
- setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ setMetadata(data.rdd.map(_.toInstance(1.0)),
+ categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numClasses)
}
/**
* Set label metadata (particularly the number of classes) on a DataFrame.
+ *
* @param data Dataset. Categorical features and labels must already have 0-based indices.
* This must be non-empty.
* @param numClasses Number of classes label can take. If 0, mark as continuous.
@@ -124,8 +129,8 @@ private[ml] object TreeTests extends SparkFunSuite {
* make mistakes such as creating loops of Nodes.
*/
private def checkEqual(a: Node, b: Node): Unit = {
- assert(a.prediction === b.prediction)
- assert(a.impurity === b.impurity)
+ assert(a.prediction ~== b.prediction absTol 1e-8)
+ assert(a.impurity ~== b.impurity absTol 1e-8)
(a, b) match {
case (aye: InternalNode, bee: InternalNode) =>
assert(aye.split === bee.split)
@@ -156,6 +161,7 @@ private[ml] object TreeTests extends SparkFunSuite {
/**
* Helper method for constructing a tree for testing.
* Given left, right children, construct a parent node.
+ *
* @param split Split for parent node
* @return Parent node with children attached
*/
@@ -163,8 +169,8 @@ private[ml] object TreeTests extends SparkFunSuite {
val leftImp = left.impurityStats
val rightImp = right.impurityStats
val parentImp = leftImp.copy.add(rightImp)
- val leftWeight = leftImp.count / parentImp.count.toDouble
- val rightWeight = rightImp.count / parentImp.count.toDouble
+ val leftWeight = leftImp.count / parentImp.count
+ val rightWeight = rightImp.count / parentImp.count
val gain = parentImp.calculate() -
(leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate())
val pred = parentImp.predict
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index f1ed568d5e60..4b1bd35313c5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol}
+import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -246,8 +246,8 @@ object MLTestingUtils extends SparkFunSuite {
seed: Long): Unit = {
val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances(
data, seed)
- val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData)
+ val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData)
modelEquals(weightedModel, overSampledModel)
}
@@ -260,15 +260,17 @@ object MLTestingUtils extends SparkFunSuite {
data: Dataset[LabeledPoint],
estimator: E with HasWeightCol,
numClasses: Int,
- modelEquals: (M, M) => Unit): Unit = {
+ modelEquals: (M, M) => Unit,
+ outlierRatio: Int): Unit = {
import data.sqlContext.implicits._
val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap {
case Instance(l, w, f) =>
val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1
- List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
+ List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f))
}
val trueModel = estimator.set(estimator.weightCol, "").fit(data)
- val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS)
+ val outlierModel = estimator.set(estimator.weightCol, "weight")
+ .fit(outlierDS)
modelEquals(trueModel, outlierModel)
}
@@ -281,10 +283,26 @@ object MLTestingUtils extends SparkFunSuite {
estimator: E with HasWeightCol,
modelEquals: (M, M) => Unit): Unit = {
estimator.set(estimator.weightCol, "weight")
- val models = Seq(0.001, 1.0, 1000.0).map { w =>
+ val models = Seq(0.01, 1.0, 1000.0).map { w =>
val df = data.withColumn("weight", lit(w))
estimator.fit(df)
}
models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)}
}
+
+ def modelPredictionEquals[M <: PredictionModel[_, M]](
+ data: DataFrame,
+ compareFunc: (Double, Double) => Boolean,
+ fractionInTol: Double)(
+ model1: M,
+ model2: M): Unit = {
+ val pred1 = model1.transform(data).select(model1.getPredictionCol).collect()
+ val pred2 = model2.transform(data).select(model2.getPredictionCol).collect()
+ val inTol = pred1.zip(pred2).count { case (p1, p2) =>
+ val x = p1.getDouble(0)
+ val y = p2.getDouble(0)
+ compareFunc(x, y)
+ }
+ assert(inTol / pred1.length.toDouble >= fractionInTol)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 441d0f7614bf..f6efc84a2418 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
maxDepth = 2,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, maxDepth = 3,
numClasses = 2, maxBins = 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
@@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(strategy.isMulticlassClassification)
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
@@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 3, maxBins = maxBins,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(metadata.isUnordered(featureIndex = 0))
assert(metadata.isUnordered(featureIndex = 1))
@@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3))
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(metadata.isUnordered(featureIndex = 0))
val model = DecisionTree.train(rdd, strategy)
@@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 3, maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance(1.0)), strategy)
assert(!metadata.isUnordered(featureIndex = 0))
assert(!metadata.isUnordered(featureIndex = 1))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
index 14152cdd63bc..d4171cf441e9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -18,23 +18,62 @@
package org.apache.spark.mllib.tree
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.tree.impurity._
/**
* Test suites for [[GiniAggregator]] and [[EntropyAggregator]].
*/
class ImpuritySuite extends SparkFunSuite {
+
+ private val seed = 42
+
test("Gini impurity does not support negative labels") {
val gini = new GiniAggregator(2)
intercept[IllegalArgumentException] {
- gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
}
}
test("Entropy does not support negative labels") {
val entropy = new EntropyAggregator(2)
intercept[IllegalArgumentException] {
- entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0)
+ }
+ }
+
+ test("Classification impurities are insensitive to scaling") {
+ val rng = new scala.util.Random(seed)
+ val weightedCounts = Array.fill(5)(rng.nextDouble())
+ val smallWeightedCounts = weightedCounts.map(_ * 0.0001)
+ val largeWeightedCounts = weightedCounts.map(_ * 10000)
+ Seq(Gini, Entropy).foreach { impurity =>
+ val impurity1 = impurity.calculate(weightedCounts, weightedCounts.sum)
+ assert(impurity.calculate(smallWeightedCounts, smallWeightedCounts.sum)
+ ~== impurity1 relTol 0.005)
+ assert(impurity.calculate(largeWeightedCounts, largeWeightedCounts.sum)
+ ~== impurity1 relTol 0.005)
}
}
+ test("Regression impurities are insensitive to scaling") {
+ def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = {
+ samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) =>
+ (wn + w, wy + w * y, wyy + w * y * y)
+ }
+ }
+ val rng = new scala.util.Random(seed)
+ val samples = Array.fill(10)(rng.nextDouble())
+ val _weights = Array.fill(10)(rng.nextDouble())
+ val smallWeights = _weights.map(_ * 0.0001)
+ val largeWeights = _weights.map(_ * 10000)
+ val (count, sum, sumSquared) = computeStats(samples, _weights)
+ Seq(Variance).foreach { impurity =>
+ val impurity1 = impurity.calculate(count, sum, sumSquared)
+ val (smallCount, smallSum, smallSumSquared) = computeStats(samples, smallWeights)
+ val (largeCount, largeSum, largeSumSquared) = computeStats(samples, largeWeights)
+ assert(impurity.calculate(smallCount, smallSum, smallSumSquared) ~== impurity1 relTol 0.005)
+ assert(impurity.calculate(largeCount, largeSum, largeSumSquared) ~== impurity1 relTol 0.005)
+ }
+ }
+
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7e6e14352338..c41b67b8f6bf 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,9 @@ object MimaExcludes {
// Exclude rules for 2.2.x
lazy val v22excludes = v21excludes ++ Seq(
+ // [SPARK-9478][ML] Add sample weights to decision trees
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.this"),
+
// [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"),