diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala index 6c79d77f142e..fb98bcb9d6b0 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala @@ -52,7 +52,7 @@ object TestingUtils { /** * Private helper function for comparing two values using absolute tolerance. */ - private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { + private[ml] def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = { // Special case for NaNs if (x.isNaN && y.isNaN) { return true diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 7e5790ab70ee..e35e6ce7fdad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -77,17 +77,37 @@ abstract class Classifier[ * @note Throws `SparkException` if any label is a non-integer or is negative */ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { - require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + - s" $numClasses, but requires numClasses > 0.") + validateNumClasses(numClasses) dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => - require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + - s" dataset with invalid label $label. Labels must be integers in range" + - s" [0, $numClasses).") + validateLabel(label, numClasses) LabeledPoint(label, features) } } + /** + * Validates that number of classes is greater than zero. + * + * @param numClasses Number of classes label can take. + */ + protected def validateNumClasses(numClasses: Int): Unit = { + require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + + s" $numClasses, but requires numClasses > 0.") + } + + /** + * Validates the label on the classifier is a valid integer in the range [0, numClasses). + * + * @param label The label to validate. + * @param numClasses Number of classes label can take. Labels must be integers in the range + * [0, numClasses). + */ + protected def validateLabel(label: Double, numClasses: Int): Unit = { + require(label.toLong == label && label >= 0 && label < numClasses, s"Classifier was given" + + s" dataset with invalid label $label. Labels must be integers in range" + + s" [0, $numClasses).") + } + /** * Get the number of classes. This looks in column metadata first, and if that is missing, * then this assumes classes are indexed 0,1,...,numClasses-1 and computes numClasses diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index d9292a547676..200ac0032e51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -22,10 +22,12 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.{DecisionTreeModel, Node, TreeClassifierParams} import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ @@ -33,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Dataset - +import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.DoubleType /** * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning) @@ -66,6 +69,9 @@ class DecisionTreeClassifier @Since("1.4.0") ( def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ + @Since("3.0.0") + def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value) + @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) @@ -97,6 +103,16 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") def setSeed(value: Long): this.type = set(seed, value) + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. + * + * @group setParam + */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train( dataset: Dataset[_]): DecisionTreeClassificationModel = instrumented { instr => instr.logPipelineStage(this) @@ -104,22 +120,27 @@ class DecisionTreeClassifier @Since("1.4.0") ( val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = getNumClasses(dataset) - instr.logNumClasses(numClasses) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) + validateNumClasses(numClasses) + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances = + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + validateLabel(label, numClasses) + Instance(label, weight, features) + } val strategy = getOldStrategy(categoricalFeatures, numClasses) - + instr.logNumClasses(numClasses) instr.logParams(this, labelCol, featuresCol, predictionCol, rawPredictionCol, probabilityCol, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed) - val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeClassificationModel] @@ -128,13 +149,13 @@ class DecisionTreeClassifier @Since("1.4.0") ( /** (private[ml]) Train a decision tree on an RDD */ private[ml] def train(data: RDD[LabeledPoint], oldStrategy: OldStrategy): DecisionTreeClassificationModel = instrumented { instr => + val instances = data.map(_.toInstance) instr.logPipelineStage(this) - instr.logDataset(data) + instr.logDataset(instances) instr.logParams(this, maxDepth, maxBins, minInstancesPerNode, minInfoGain, maxMemoryInMB, cacheNodeIds, checkpointInterval, impurity, seed) - - val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, instr = Some(instr), parentUID = Some(uid)) + val trees = RandomForest.run(instances, oldStrategy, numTrees = 1, + featureSubsetStrategy = "all", seed = 0L, instr = Some(instr), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeClassificationModel] } @@ -180,6 +201,7 @@ class DecisionTreeClassificationModel private[ml] ( /** * Construct a decision tree classification model. + * * @param rootNode Root node of tree, with other nodes attached. */ private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 0a3bfd1f85e0..3500f2ad52a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -21,20 +21,21 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.functions._ - +import org.apache.spark.sql.functions.{col, udf} /** * Random Forest learning algorithm for @@ -130,7 +131,7 @@ class RandomForestClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) + val instances: RDD[Instance] = extractLabeledPoints(dataset, numClasses).map(_.toInstance) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) @@ -139,7 +140,7 @@ class RandomForestClassifier @Since("1.4.0") ( minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) val trees = RandomForest - .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) + .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeClassificationModel]) val numFeatures = trees.head.numFeatures diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala index 412954f7b2d5..f6667b73304a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabeledPoint.scala @@ -37,4 +37,13 @@ case class LabeledPoint(@Since("2.0.0") label: Double, @Since("2.0.0") features: override def toString: String = { s"($label,$features)" } + + private[spark] def toInstance(weight: Double): Instance = { + Instance(label, weight, features) + } + + private[spark] def toInstance: Instance = { + Instance(label, 1.0, features) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index faadc4d7b4cc..525479151aa0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -23,9 +23,10 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.DecisionTreeModelReadWrite._ import org.apache.spark.ml.tree.impl.RandomForest @@ -34,8 +35,9 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType /** @@ -65,6 +67,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) /** @group setParam */ + @Since("3.0.0") + def setMinWeightFractionPerNode(value: Double): this.type = set(minWeightFractionPerNode, value) + @Since("1.4.0") def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) @@ -100,18 +105,33 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S @Since("2.0.0") def setVarianceCol(value: String): this.type = set(varianceCol, value) + /** + * Sets the value of param [[weightCol]]. + * If this is not set or empty, we treat all instance weights as 1.0. + * Default is not set, so all instances have weight one. + * + * @group setParam + */ + @Since("3.0.0") + def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train( dataset: Dataset[_]): DecisionTreeRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) + val instances = + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } val strategy = getOldStrategy(categoricalFeatures) instr.logPipelineStage(this) - instr.logDataset(oldDataset) + instr.logDataset(instances) instr.logParams(this, params: _*) - val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + val trees = RandomForest.run(instances, strategy, numTrees = 1, featureSubsetStrategy = "all", seed = $(seed), instr = Some(instr), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] @@ -126,8 +146,9 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S instr.logDataset(data) instr.logParams(this, params: _*) - val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy, - seed = $(seed), instr = Some(instr), parentUID = Some(uid)) + val instances = data.map(_.toInstance) + val trees = RandomForest.run(instances, oldStrategy, numTrees = 1, + featureSubsetStrategy, seed = $(seed), instr = Some(instr), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } @@ -155,6 +176,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor * * Decision tree (Wikipedia) model for regression. * It supports both continuous and categorical features. + * * @param rootNode Root of the decision tree */ @Since("1.4.0") @@ -173,6 +195,7 @@ class DecisionTreeRegressionModel private[ml] ( /** * Construct a decision tree regression model. + * * @param rootNode Root node of tree, with other nodes attached. */ private[ml] def this(rootNode: Node, numFeatures: Int) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index afa9a646412b..6f36dfb9ff51 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -22,7 +22,6 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Since import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree._ @@ -32,10 +31,8 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} -import org.apache.spark.sql.functions._ - +import org.apache.spark.sql.functions.{col, udf} /** * Random Forest @@ -119,18 +116,19 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S dataset: Dataset[_]): RandomForestRegressionModel = instrumented { instr => val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + + val instances = extractLabeledPoints(dataset).map(_.toInstance) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) instr.logPipelineStage(this) - instr.logDataset(dataset) + instr.logDataset(instances) instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval) val trees = RandomForest - .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) + .run(instances, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeRegressionModel]) val numFeatures = trees.head.numFeatures diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala index 4e372702f0c6..c896b1589a93 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala @@ -33,13 +33,13 @@ import org.apache.spark.util.random.XORShiftRandom * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. * * @param datum Data instance - * @param subsampleWeights Weight of this instance in each subsampled dataset. - * - * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted - * dataset support, update. (We store subsampleWeights as Double for this future extension.) + * @param subsampleCounts Number of samples of this instance in each subsampled dataset. + * @param sampleWeight The weight of this instance. */ -private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) - extends Serializable +private[spark] class BaggedPoint[Datum]( + val datum: Datum, + val subsampleCounts: Array[Int], + val sampleWeight: Double = 1.0) extends Serializable private[spark] object BaggedPoint { @@ -52,6 +52,7 @@ private[spark] object BaggedPoint { * @param subsamplingRate Fraction of the training data used for learning decision tree. * @param numSubsamples Number of subsamples of this RDD to take. * @param withReplacement Sampling with/without replacement. + * @param extractSampleWeight A function to get the sample weight of each datum. * @param seed Random seed. * @return BaggedPoint dataset representation. */ @@ -60,12 +61,14 @@ private[spark] object BaggedPoint { subsamplingRate: Double, numSubsamples: Int, withReplacement: Boolean, + extractSampleWeight: (Datum => Double) = (_: Datum) => 1.0, seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { + // TODO: implement weighted bootstrapping if (withReplacement) { convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) } else { if (numSubsamples == 1 && subsamplingRate == 1.0) { - convertToBaggedRDDWithoutSampling(input) + convertToBaggedRDDWithoutSampling(input, extractSampleWeight) } else { convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) } @@ -82,16 +85,15 @@ private[spark] object BaggedPoint { val rng = new XORShiftRandom rng.setSeed(seed + partitionIndex + 1) instances.map { instance => - val subsampleWeights = new Array[Double](numSubsamples) + val subsampleCounts = new Array[Int](numSubsamples) var subsampleIndex = 0 while (subsampleIndex < numSubsamples) { - val x = rng.nextDouble() - subsampleWeights(subsampleIndex) = { - if (x < subsamplingRate) 1.0 else 0.0 + if (rng.nextDouble() < subsamplingRate) { + subsampleCounts(subsampleIndex) = 1 } subsampleIndex += 1 } - new BaggedPoint(instance, subsampleWeights) + new BaggedPoint(instance, subsampleCounts) } } } @@ -106,20 +108,20 @@ private[spark] object BaggedPoint { val poisson = new PoissonDistribution(subsample) poisson.reseedRandomGenerator(seed + partitionIndex + 1) instances.map { instance => - val subsampleWeights = new Array[Double](numSubsamples) + val subsampleCounts = new Array[Int](numSubsamples) var subsampleIndex = 0 while (subsampleIndex < numSubsamples) { - subsampleWeights(subsampleIndex) = poisson.sample() + subsampleCounts(subsampleIndex) = poisson.sample() subsampleIndex += 1 } - new BaggedPoint(instance, subsampleWeights) + new BaggedPoint(instance, subsampleCounts) } } } private def convertToBaggedRDDWithoutSampling[Datum] ( - input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { - input.map(datum => new BaggedPoint(datum, Array(1.0))) + input: RDD[Datum], + extractSampleWeight: (Datum => Double)): RDD[BaggedPoint[Datum]] = { + input.map(datum => new BaggedPoint(datum, Array(1), extractSampleWeight(datum))) } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 5aeea1443d49..17ec161f2f50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -104,16 +104,21 @@ private[spark] class DTStatsAggregator( /** * Update the stats for a given (feature, bin) for ordered features, using the given label. */ - def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { + def update( + featureIndex: Int, + binIndex: Int, + label: Double, + numSamples: Int, + sampleWeight: Double): Unit = { val i = featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) + impurityAggregator.update(allStats, i, label, numSamples, sampleWeight) } /** * Update the parent node stats using the given label. */ - def updateParent(label: Double, instanceWeight: Double): Unit = { - impurityAggregator.update(parentStats, 0, label, instanceWeight) + def updateParent(label: Double, numSamples: Int, sampleWeight: Double): Unit = { + impurityAggregator.update(parentStats, 0, label, numSamples, sampleWeight) } /** @@ -127,9 +132,10 @@ private[spark] class DTStatsAggregator( featureOffset: Int, binIndex: Int, label: Double, - instanceWeight: Double): Unit = { + numSamples: Int, + sampleWeight: Double): Unit = { impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, - label, instanceWeight) + label, numSamples, sampleWeight) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 53189e0797b6..8f8a17171f98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.util.Try import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.tree.TreeEnsembleParams import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -32,16 +32,20 @@ import org.apache.spark.rdd.RDD /** * Learning and dataset metadata for DecisionTree. * + * @param weightedNumExamples Weighted count of samples in the tree. * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. * For regression: fixed at 0 (no meaning). * @param maxBins Maximum number of bins, for all features. * @param featureArity Map: categorical feature index to arity. * I.e., the feature takes values in {0, ..., arity - 1}. * @param numBins Number of bins for each feature. + * @param minWeightFractionPerNode The minimum fraction of the total sample weight that must be + * present in a leaf node in order to be considered a valid split. */ private[spark] class DecisionTreeMetadata( val numFeatures: Int, val numExamples: Long, + val weightedNumExamples: Double, val numClasses: Int, val maxBins: Int, val featureArity: Map[Int, Int], @@ -51,6 +55,7 @@ private[spark] class DecisionTreeMetadata( val quantileStrategy: QuantileStrategy, val maxDepth: Int, val minInstancesPerNode: Int, + val minWeightFractionPerNode: Double, val minInfoGain: Double, val numTrees: Int, val numFeaturesPerNode: Int) extends Serializable { @@ -67,6 +72,8 @@ private[spark] class DecisionTreeMetadata( def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + def minWeightPerNode: Double = minWeightFractionPerNode * weightedNumExamples + /** * Number of splits for the given feature. * For unordered features, there is 1 bin per split. @@ -104,7 +111,7 @@ private[spark] object DecisionTreeMetadata extends Logging { * as well as the number of splits and bins for each feature. */ def buildMetadata( - input: RDD[LabeledPoint], + input: RDD[Instance], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String): DecisionTreeMetadata = { @@ -115,7 +122,11 @@ private[spark] object DecisionTreeMetadata extends Logging { } require(numFeatures > 0, s"DecisionTree requires number of features > 0, " + s"but was given an empty features vector") - val numExamples = input.count() + val (numExamples, weightSum) = input.aggregate((0L, 0.0))( + seqOp = (cw, instance) => (cw._1 + 1L, cw._2 + instance.weight), + combOp = (cw1, cw2) => (cw1._1 + cw2._1, cw1._2 + cw2._2) + ) + val numClasses = strategy.algo match { case Classification => strategy.numClasses case Regression => 0 @@ -206,17 +217,18 @@ private[spark] object DecisionTreeMetadata extends Logging { } } - new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, - strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, + new DecisionTreeMetadata(numFeatures, numExamples, weightSum, numClasses, + numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + strategy.minInstancesPerNode, strategy.minWeightFractionPerNode, strategy.minInfoGain, + numTrees, numFeaturesPerNode) } /** * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. */ def buildMetadata( - input: RDD[LabeledPoint], + input: RDD[Instance], strategy: Strategy): DecisionTreeMetadata = { buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 822abd2d3522..fb4c321a146f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -24,10 +24,12 @@ import scala.util.Random import org.apache.spark.internal.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.impl.Utils import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.ml.util.Instrumentation +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats @@ -90,6 +92,24 @@ private[spark] object RandomForest extends Logging with Serializable { strategy: OldStrategy, numTrees: Int, featureSubsetStrategy: String, + seed: Long): Array[DecisionTreeModel] = { + val instances = input.map { case LabeledPoint(label, features) => + Instance(label, 1.0, features.asML) + } + run(instances, strategy, numTrees, featureSubsetStrategy, seed, None) + } + + /** + * Train a random forest. + * + * @param input Training data: RDD of `Instance` + * @return an unweighted set of trees + */ + def run( + input: RDD[Instance], + strategy: OldStrategy, + numTrees: Int, + featureSubsetStrategy: String, seed: Long, instr: Option[Instrumentation], prune: Boolean = true, // exposed for testing only, real trees are always pruned @@ -101,9 +121,10 @@ private[spark] object RandomForest extends Logging with Serializable { timer.start("init") - val retaggedInput = input.retag(classOf[LabeledPoint]) + val retaggedInput = input.retag(classOf[Instance]) val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + instr match { case Some(instrumentation) => instrumentation.logNumFeatures(metadata.numFeatures) @@ -132,7 +153,8 @@ private[spark] object RandomForest extends Logging with Serializable { val withReplacement = numTrees > 1 val baggedInput = BaggedPoint - .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed) + .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, + (tp: TreePoint) => tp.weight, seed = seed) .persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree @@ -254,19 +276,21 @@ private[spark] object RandomForest extends Logging with Serializable { * For unordered features, bins correspond to subsets of categories; either the left or right bin * for each subset is updated. * - * @param agg Array storing aggregate calculation, with a set of sufficient statistics for - * each (feature, bin). - * @param treePoint Data point being aggregated. - * @param splits possible splits indexed (numFeatures)(numSplits) - * @param unorderedFeatures Set of indices of unordered features. - * @param instanceWeight Weight (importance) of instance in dataset. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits Possible splits indexed (numFeatures)(numSplits) + * @param unorderedFeatures Set of indices of unordered features. + * @param numSamples Number of times this instance occurs in the sample. + * @param sampleWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, splits: Array[Array[Split]], unorderedFeatures: Set[Int], - instanceWeight: Double, + numSamples: Int, + sampleWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { val numFeaturesPerNode = if (featuresForNode.nonEmpty) { // Use subsampled features @@ -293,14 +317,15 @@ private[spark] object RandomForest extends Logging with Serializable { var splitIndex = 0 while (splitIndex < numSplits) { if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, numSamples, + sampleWeight) } splitIndex += 1 } } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) - agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) + agg.update(featureIndexIdx, binIndex, treePoint.label, numSamples, sampleWeight) } featureIndexIdx += 1 } @@ -314,12 +339,14 @@ private[spark] object RandomForest extends Logging with Serializable { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (feature, bin). * @param treePoint Data point being aggregated. - * @param instanceWeight Weight (importance) of instance in dataset. + * @param numSamples Number of times this instance occurs in the sample. + * @param sampleWeight Weight (importance) of instance in dataset. */ private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - instanceWeight: Double, + numSamples: Int, + sampleWeight: Double, featuresForNode: Option[Array[Int]]): Unit = { val label = treePoint.label @@ -329,7 +356,7 @@ private[spark] object RandomForest extends Logging with Serializable { var featureIndexIdx = 0 while (featureIndexIdx < featuresForNode.get.length) { val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) - agg.update(featureIndexIdx, binIndex, label, instanceWeight) + agg.update(featureIndexIdx, binIndex, label, numSamples, sampleWeight) featureIndexIdx += 1 } } else { @@ -338,7 +365,7 @@ private[spark] object RandomForest extends Logging with Serializable { var featureIndex = 0 while (featureIndex < numFeatures) { val binIndex = treePoint.binnedFeatures(featureIndex) - agg.update(featureIndex, binIndex, label, instanceWeight) + agg.update(featureIndex, binIndex, label, numSamples, sampleWeight) featureIndex += 1 } } @@ -427,14 +454,16 @@ private[spark] object RandomForest extends Logging with Serializable { if (nodeInfo != null) { val aggNodeIndex = nodeInfo.nodeIndexInGroup val featuresForNode = nodeInfo.featureSubset - val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + val numSamples = baggedPoint.subsampleCounts(treeIndex) + val sampleWeight = baggedPoint.sampleWeight if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight, + featuresForNode) } else { mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, - metadata.unorderedFeatures, instanceWeight, featuresForNode) + metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode) } - agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight) } } @@ -594,8 +623,8 @@ private[spark] object RandomForest extends Logging with Serializable { if (!isLeaf) { node.split = Some(split) val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth - val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) - val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON) + val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON) node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), @@ -659,15 +688,20 @@ private[spark] object RandomForest extends Logging with Serializable { stats.impurity } + val leftRawCount = leftImpurityCalculator.rawCount + val rightRawCount = rightImpurityCalculator.rawCount val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count val totalCount = leftCount + rightCount - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. - if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { + val violatesMinInstancesPerNode = (leftRawCount < metadata.minInstancesPerNode) || + (rightRawCount < metadata.minInstancesPerNode) + val violatesMinWeightPerNode = (leftCount < metadata.minWeightPerNode) || + (rightCount < metadata.minWeightPerNode) + // If left child or right child doesn't satisfy minimum weight per node or minimum + // instances per node, then this split is invalid, return invalid information gain stats. + if (violatesMinInstancesPerNode || violatesMinWeightPerNode) { return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } @@ -734,7 +768,8 @@ private[spark] object RandomForest extends Logging with Serializable { // Find best split. val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) @@ -876,14 +911,14 @@ private[spark] object RandomForest extends Logging with Serializable { * and for multiclass classification with a high-arity feature, * there is one bin per category. * - * @param input Training data: RDD of [[LabeledPoint]] + * @param input Training data: RDD of [[Instance]] * @param metadata Learning and dataset metadata * @param seed random seed * @return Splits, an Array of [[Split]] * of size (numFeatures, numSplits) */ protected[tree] def findSplits( - input: RDD[LabeledPoint], + input: RDD[Instance], metadata: DecisionTreeMetadata, seed: Long): Array[Array[Split]] = { @@ -898,14 +933,14 @@ private[spark] object RandomForest extends Logging with Serializable { logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) } else { - input.sparkContext.emptyRDD[LabeledPoint] + input.sparkContext.emptyRDD[Instance] } findSplitsBySorting(sampledInput, metadata, continuousFeatures) } private def findSplitsBySorting( - input: RDD[LabeledPoint], + input: RDD[Instance], metadata: DecisionTreeMetadata, continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = { @@ -917,7 +952,8 @@ private[spark] object RandomForest extends Logging with Serializable { input .flatMap { point => - continuousFeatures.map(idx => (idx, point.features(idx))).filter(_._2 != 0.0) + continuousFeatures.map(idx => (idx, (point.weight, point.features(idx)))) + .filter(_._2._2 != 0.0) }.groupByKey(numPartitions) .map { case (idx, samples) => val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) @@ -982,7 +1018,7 @@ private[spark] object RandomForest extends Logging with Serializable { * could be different from the specified `numSplits`. * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. * - * @param featureSamples feature values of each sample + * @param featureSamples feature values and sample weights of each sample * @param metadata decision tree metadata * NOTE: `metadata.numbins` will be changed accordingly * if there are not enough splits to be found @@ -990,7 +1026,7 @@ private[spark] object RandomForest extends Logging with Serializable { * @return array of split thresholds */ private[tree] def findSplitsForContinuousFeature( - featureSamples: Iterable[Double], + featureSamples: Iterable[(Double, Double)], metadata: DecisionTreeMetadata, featureIndex: Int): Array[Double] = { require(metadata.isContinuous(featureIndex), @@ -1002,19 +1038,24 @@ private[spark] object RandomForest extends Logging with Serializable { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value except zero value - val partNumSamples = featureSamples.size - val partValueCountMap = scala.collection.mutable.Map[Double, Int]() - featureSamples.foreach { x => - partValueCountMap(x) = partValueCountMap.getOrElse(x, 0) + 1 + val partValueCountMap = mutable.Map[Double, Double]() + var partNumSamples = 0.0 + var unweightedNumSamples = 0.0 + featureSamples.foreach { case (sampleWeight, feature) => + partValueCountMap(feature) = partValueCountMap.getOrElse(feature, 0.0) + sampleWeight; + partNumSamples += sampleWeight; + unweightedNumSamples += 1.0 } // Calculate the expected number of samples for finding splits - val numSamples = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt + val weightedNumSamples = samplesFractionForFindSplits(metadata) * + metadata.weightedNumExamples // add expected zero value count and get complete statistics - val valueCountMap: Map[Double, Int] = if (numSamples - partNumSamples > 0) { - partValueCountMap.toMap + (0.0 -> (numSamples - partNumSamples)) + val tolerance = Utils.EPSILON * unweightedNumSamples * unweightedNumSamples + val valueCountMap = if (weightedNumSamples - partNumSamples > tolerance) { + partValueCountMap + (0.0 -> (weightedNumSamples - partNumSamples)) } else { - partValueCountMap.toMap + partValueCountMap } // sort distinct values @@ -1031,7 +1072,7 @@ private[spark] object RandomForest extends Logging with Serializable { .toArray } else { // stride between splits - val stride: Double = numSamples.toDouble / (numSplits + 1) + val stride: Double = weightedNumSamples / (numSplits + 1) logDebug("stride = " + stride) // iterate `valueCount` to find splits diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index a6ac64a0463c..72440b2c57aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tree.impl -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.rdd.RDD @@ -36,10 +36,12 @@ import org.apache.spark.rdd.RDD * @param label Label from LabeledPoint * @param binnedFeatures Binned feature values. * Same length as LabeledPoint.features, but values are bin indices. + * @param weight Sample weight for this TreePoint. */ -private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) - extends Serializable { -} +private[spark] class TreePoint( + val label: Double, + val binnedFeatures: Array[Int], + val weight: Double) extends Serializable private[spark] object TreePoint { @@ -52,7 +54,7 @@ private[spark] object TreePoint { * @return TreePoint dataset representation */ def convertToTreeRDD( - input: RDD[LabeledPoint], + input: RDD[Instance], splits: Array[Array[Split]], metadata: DecisionTreeMetadata): RDD[TreePoint] = { // Construct arrays for featureArity for efficiency in the inner loop. @@ -82,18 +84,18 @@ private[spark] object TreePoint { * for categorical features. */ private def labeledPointToTreePoint( - labeledPoint: LabeledPoint, + instance: Instance, thresholds: Array[Array[Double]], featureArity: Array[Int]): TreePoint = { - val numFeatures = labeledPoint.features.size + val numFeatures = instance.features.size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { arr(featureIndex) = - findBin(featureIndex, labeledPoint, featureArity(featureIndex), thresholds(featureIndex)) + findBin(featureIndex, instance, featureArity(featureIndex), thresholds(featureIndex)) featureIndex += 1 } - new TreePoint(labeledPoint.label, arr) + new TreePoint(instance.label, arr, instance.weight) } /** @@ -106,10 +108,10 @@ private[spark] object TreePoint { */ private def findBin( featureIndex: Int, - labeledPoint: LabeledPoint, + instance: Instance, featureArity: Int, thresholds: Array[Double]): Int = { - val featureValue = labeledPoint.features(featureIndex) + val featureValue = instance.features(featureIndex) if (featureArity == 0) { val idx = java.util.Arrays.binarySearch(thresholds, featureValue) @@ -125,7 +127,7 @@ private[spark] object TreePoint { s"DecisionTree given invalid data:" + s" Feature $featureIndex is categorical with values in {0,...,${featureArity - 1}," + s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) + s" Bad data point: $instance") } featureValue.toInt } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 4aa4c3617e7f..51d5d5c58c57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -282,6 +282,7 @@ private[ml] object DecisionTreeModelReadWrite { * * @param id Index used for tree reconstruction. Indices follow a pre-order traversal. * @param impurityStats Stats array. Impurity type is stored in metadata. + * @param rawCount The unweighted number of samples falling in this node. * @param gain Gain, or arbitrary value if leaf node. * @param leftChild Left child index, or arbitrary value if leaf node. * @param rightChild Right child index, or arbitrary value if leaf node. @@ -292,6 +293,7 @@ private[ml] object DecisionTreeModelReadWrite { prediction: Double, impurity: Double, impurityStats: Array[Double], + rawCount: Long, gain: Double, leftChild: Int, rightChild: Int, @@ -311,11 +313,12 @@ private[ml] object DecisionTreeModelReadWrite { val (leftNodeData, leftIdx) = build(n.leftChild, id + 1) val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1) val thisNodeData = NodeData(id, n.prediction, n.impurity, n.impurityStats.stats, - n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split)) + n.impurityStats.rawCount, n.gain, leftNodeData.head.id, rightNodeData.head.id, + SplitData(n.split)) (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx) case _: LeafNode => (Seq(NodeData(id, node.prediction, node.impurity, node.impurityStats.stats, - -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), + node.impurityStats.rawCount, -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), id) } } @@ -360,7 +363,8 @@ private[ml] object DecisionTreeModelReadWrite { // traversal, this guarantees that child nodes will be built before parent nodes. val finalNodes = new Array[Node](nodes.length) nodes.reverseIterator.foreach { case n: NodeData => - val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats) + val impurityStats = + ImpurityCalculator.getCalculator(impurityType, n.impurityStats, n.rawCount) val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index c06c68d44ae1..df01dc007882 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * Note: Marked as private and DeveloperApi since this may be made public in the future. */ private[ml] trait DecisionTreeParams extends PredictorParams - with HasCheckpointInterval with HasSeed { + with HasCheckpointInterval with HasSeed with HasWeightCol { /** * Maximum depth of the tree (>= 0). @@ -74,6 +74,21 @@ private[ml] trait DecisionTreeParams extends PredictorParams " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + " Should be >= 1.", ParamValidators.gtEq(1)) + /** + * Minimum fraction of the weighted sample count that each child must have after split. + * If a split causes the fraction of the total weight in the left or right child to be less than + * minWeightFractionPerNode, the split will be discarded as invalid. + * Should be in the interval [0.0, 0.5). + * (default = 0.0) + * @group param + */ + final val minWeightFractionPerNode: DoubleParam = new DoubleParam(this, + "minWeightFractionPerNode", "Minimum fraction of the weighted sample count that each child " + + "must have after split. If a split causes the fraction of the total weight in the left or " + + "right child to be less than minWeightFractionPerNode, the split will be discarded as " + + "invalid. Should be in interval [0.0, 0.5)", + ParamValidators.inRange(0.0, 0.5, lowerInclusive = true, upperInclusive = false)) + /** * Minimum information gain for a split to be considered at a tree node. * Should be >= 0.0. @@ -107,8 +122,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, + minWeightFractionPerNode -> 0.0, minInfoGain -> 0.0, maxMemoryInMB -> 256, + cacheNodeIds -> false, checkpointInterval -> 10) /** @group getParam */ final def getMaxDepth: Int = $(maxDepth) @@ -119,6 +135,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** @group getParam */ final def getMinInstancesPerNode: Int = $(minInstancesPerNode) + /** @group getParam */ + final def getMinWeightFractionPerNode: Double = $(minWeightFractionPerNode) + /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) @@ -143,6 +162,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams strategy.maxMemoryInMB = getMaxMemoryInMB strategy.minInfoGain = getMinInfoGain strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.minWeightFractionPerNode = getMinWeightFractionPerNode strategy.useNodeIdCache = getCacheNodeIds strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index a8c5286f3dc1..94224be80752 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -23,6 +23,7 @@ import scala.util.Try import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.tree.{DecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams} import org.apache.spark.ml.tree.impl.{RandomForest => NewRandomForest} import org.apache.spark.mllib.regression.LabeledPoint @@ -91,8 +92,8 @@ private class RandomForest ( * @return RandomForestModel that can be used for prediction. */ def run(input: RDD[LabeledPoint]): RandomForestModel = { - val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees, - featureSubsetStrategy, seed.toLong, None) + val trees: Array[NewDTModel] = + NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong) new RandomForestModel(strategy.algo, trees.map(_.toOld)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 58e8f5be7b9f..d9dcb8001340 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -80,7 +80,8 @@ class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, - @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { + @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10, + @Since("3.0.0") @BeanProperty var minWeightFractionPerNode: Double = 0.0) extends Serializable { /** */ @@ -96,6 +97,31 @@ class Strategy @Since("1.3.0") ( isMulticlassClassification && (categoricalFeaturesInfo.size > 0) } + // scalastyle:off argcount + /** + * Backwards compatible constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] + */ + @Since("1.0.0") + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClasses: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int, Int], + minInstancesPerNode: Int, + minInfoGain: Double, + maxMemoryInMB: Int, + subsamplingRate: Double, + useNodeIdCache: Boolean, + checkpointInterval: Int) { + this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, + subsamplingRate, useNodeIdCache, checkpointInterval, 0.0) + } + // scalastyle:on argcount + /** * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] */ @@ -108,7 +134,8 @@ class Strategy @Since("1.3.0") ( maxBins: Int, categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { this(algo, impurity, maxDepth, numClasses, maxBins, Sort, - categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + minWeightFractionPerNode = 0.0) } /** @@ -171,8 +198,9 @@ class Strategy @Since("1.3.0") ( @Since("1.2.0") def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, - quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, - maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) + quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, + minInfoGain, maxMemoryInMB, subsamplingRate, useNodeIdCache, + checkpointInterval, minWeightFractionPerNode) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index d4448da9eef5..f01a98e74886 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -83,23 +83,29 @@ object Entropy extends Impurity { * @param numClasses Number of classes for label. */ private[spark] class EntropyAggregator(numClasses: Int) - extends ImpurityAggregator(numClasses) with Serializable { + extends ImpurityAggregator(numClasses + 1) with Serializable { /** * Update stats for one (node, feature, bin) with the given label. * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { - if (label >= statsSize) { + def update( + allStats: Array[Double], + offset: Int, + label: Double, + numSamples: Int, + sampleWeight: Double): Unit = { + if (label >= numClasses) { throw new IllegalArgumentException(s"EntropyAggregator given label $label" + - s" but requires label < numClasses (= $statsSize).") + s" but requires label < numClasses (= ${numClasses}).") } if (label < 0) { throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s"but requires label is non-negative.") } - allStats(offset + label.toInt) += instanceWeight + allStats(offset + label.toInt) += numSamples * sampleWeight + allStats(offset + statsSize - 1) += numSamples } /** @@ -108,7 +114,8 @@ private[spark] class EntropyAggregator(numClasses: Int) * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { - new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) + new EntropyCalculator(allStats.view(offset, offset + statsSize - 1).toArray, + allStats(offset + statsSize - 1).toLong) } } @@ -118,12 +125,13 @@ private[spark] class EntropyAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class EntropyCalculator(stats: Array[Double], var rawCount: Long) + extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. */ - def copy: EntropyCalculator = new EntropyCalculator(stats.clone()) + def copy: EntropyCalculator = new EntropyCalculator(stats.clone(), rawCount) /** * Calculate the impurity from the stored sufficient statistics. @@ -131,9 +139,9 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal def calculate(): Double = Entropy.calculate(stats, stats.sum) /** - * Number of data points accounted for in the sufficient statistics. + * Weighted number of data points accounted for in the sufficient statistics. */ - def count: Long = stats.sum.toLong + def count: Double = stats.sum /** * Prediction which should be made based on the sufficient statistics. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c5e34ffa4f2e..913ffbbb2457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -80,23 +80,29 @@ object Gini extends Impurity { * @param numClasses Number of classes for label. */ private[spark] class GiniAggregator(numClasses: Int) - extends ImpurityAggregator(numClasses) with Serializable { + extends ImpurityAggregator(numClasses + 1) with Serializable { /** * Update stats for one (node, feature, bin) with the given label. * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { - if (label >= statsSize) { + def update( + allStats: Array[Double], + offset: Int, + label: Double, + numSamples: Int, + sampleWeight: Double): Unit = { + if (label >= numClasses) { throw new IllegalArgumentException(s"GiniAggregator given label $label" + - s" but requires label < numClasses (= $statsSize).") + s" but requires label < numClasses (= ${numClasses}).") } if (label < 0) { throw new IllegalArgumentException(s"GiniAggregator given label $label" + - s"but requires label is non-negative.") + s"but requires label to be non-negative.") } - allStats(offset + label.toInt) += instanceWeight + allStats(offset + label.toInt) += numSamples * sampleWeight + allStats(offset + statsSize - 1) += numSamples } /** @@ -105,7 +111,8 @@ private[spark] class GiniAggregator(numClasses: Int) * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { - new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) + new GiniCalculator(allStats.view(offset, offset + statsSize - 1).toArray, + allStats(offset + statsSize - 1).toLong) } } @@ -115,12 +122,13 @@ private[spark] class GiniAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class GiniCalculator(stats: Array[Double], var rawCount: Long) + extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. */ - def copy: GiniCalculator = new GiniCalculator(stats.clone()) + def copy: GiniCalculator = new GiniCalculator(stats.clone(), rawCount) /** * Calculate the impurity from the stored sufficient statistics. @@ -128,9 +136,9 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul def calculate(): Double = Gini.calculate(stats, stats.sum) /** - * Number of data points accounted for in the sufficient statistics. + * Weighted number of data points accounted for in the sufficient statistics. */ - def count: Long = stats.sum.toLong + def count: Double = stats.sum /** * Prediction which should be made based on the sufficient statistics. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index f151a6a01b65..491473490eba 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -81,7 +81,12 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit + def update( + allStats: Array[Double], + offset: Int, + label: Double, + numSamples: Int, + sampleWeight: Double): Unit /** * Get an [[ImpurityCalculator]] for a (node, feature, bin). @@ -122,6 +127,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten stats(i) += other.stats(i) i += 1 } + rawCount += other.rawCount this } @@ -139,13 +145,19 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten stats(i) -= other.stats(i) i += 1 } + rawCount -= other.rawCount this } /** - * Number of data points accounted for in the sufficient statistics. + * Weighted number of data points accounted for in the sufficient statistics. */ - def count: Long + def count: Double + + /** + * Raw number of data points accounted for in the sufficient statistics. + */ + var rawCount: Long /** * Prediction which should be made based on the sufficient statistics. @@ -185,11 +197,14 @@ private[spark] object ImpurityCalculator { * Create an [[ImpurityCalculator]] instance of the given impurity type and with * the given stats. */ - def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { + def getCalculator( + impurity: String, + stats: Array[Double], + rawCount: Long): ImpurityCalculator = { impurity.toLowerCase(Locale.ROOT) match { - case "gini" => new GiniCalculator(stats) - case "entropy" => new EntropyCalculator(stats) - case "variance" => new VarianceCalculator(stats) + case "gini" => new GiniCalculator(stats, rawCount) + case "entropy" => new EntropyCalculator(stats, rawCount) + case "variance" => new VarianceCalculator(stats, rawCount) case _ => throw new IllegalArgumentException( s"ImpurityCalculator builder did not recognize impurity type: $impurity") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index c9bf0db4de3c..a07b919271f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -66,21 +66,32 @@ object Variance extends Impurity { /** * Class for updating views of a vector of sufficient statistics, - * in order to compute impurity from a sample. + * in order to compute impurity from a sample. For variance, we track: + * - sum(w_i) + * - sum(w_i * y_i) + * - sum(w_i * y_i * y_i) + * - count(y_i) * Note: Instances of this class do not hold the data; they operate on views of the data. */ private[spark] class VarianceAggregator() - extends ImpurityAggregator(statsSize = 3) with Serializable { + extends ImpurityAggregator(statsSize = 4) with Serializable { /** * Update stats for one (node, feature, bin) with the given label. * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { + def update( + allStats: Array[Double], + offset: Int, + label: Double, + numSamples: Int, + sampleWeight: Double): Unit = { + val instanceWeight = numSamples * sampleWeight allStats(offset) += instanceWeight allStats(offset + 1) += instanceWeight * label allStats(offset + 2) += instanceWeight * label * label + allStats(offset + 3) += numSamples } /** @@ -89,7 +100,8 @@ private[spark] class VarianceAggregator() * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { - new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) + new VarianceCalculator(allStats.view(offset, offset + statsSize - 1).toArray, + allStats(offset + statsSize - 1).toLong) } } @@ -99,7 +111,8 @@ private[spark] class VarianceAggregator() * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class VarianceCalculator(stats: Array[Double], var rawCount: Long) + extends ImpurityCalculator(stats) { require(stats.length == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + @@ -108,7 +121,7 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa /** * Make a deep copy of this [[ImpurityCalculator]]. */ - def copy: VarianceCalculator = new VarianceCalculator(stats.clone()) + def copy: VarianceCalculator = new VarianceCalculator(stats.clone(), rawCount) /** * Calculate the impurity from the stored sufficient statistics. @@ -116,9 +129,9 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2)) /** - * Number of data points accounted for in the sufficient statistics. + * Weighted number of data points accounted for in the sufficient statistics. */ - def count: Long = stats(0).toLong + def count: Double = stats(0) /** * Prediction which should be made based on the sufficient statistics. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2930f4900d50..433c78237024 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -42,6 +42,8 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ + private val seed = 42 + override def beforeAll() { super.beforeAll() categoricalDataPointsRDD = @@ -250,7 +252,7 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { MLTestingUtils.checkCopyAndUids(dt, newTree) - testTransformer[(Vector, Double)](newData, newTree, + testTransformer[(Vector, Double, Double)](newData, newTree, "prediction", "rawPrediction", "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => assert(pred === rawPred.argmax, @@ -327,6 +329,49 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { dt.fit(df) } + test("training with sample weights") { + val df = { + val nPoints = 100 + val coefficients = Array( + -0.57997, 0.912083, -0.371077, + -0.16624, -0.84355, -0.048509) + + val xMean = Array(5.843, 3.057) + val xVariance = Array(0.6856, 0.1899) + + val testData = LogisticRegressionSuite.generateMultinomialLogisticInput( + coefficients, xMean, xVariance, addIntercept = true, nPoints, seed) + + sc.parallelize(testData, 4).toDF() + } + val numClasses = 3 + val predEquals = (x: Double, y: Double) => x == y + // (impurity, maxDepth) + val testParams = Seq( + ("gini", 10), + ("entropy", 10), + ("gini", 5) + ) + for ((impurity, maxDepth) <- testParams) { + val estimator = new DecisionTreeClassifier() + .setMaxDepth(maxDepth) + .setSeed(seed) + .setMinWeightFractionPerNode(0.049) + .setImpurity(impurity) + + MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeClassificationModel, + DecisionTreeClassifier](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7)) + MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeClassificationModel, + DecisionTreeClassifier](df.as[LabeledPoint], estimator, + numClasses, MLTestingUtils.modelPredictionEquals(df, predEquals, 0.8), + outlierRatio = 2) + MLTestingUtils.testOversamplingVsWeighting[DecisionTreeClassificationModel, + DecisionTreeClassifier](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, predEquals, 0.7), seed) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// @@ -350,7 +395,6 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, allParamSettings, checkModelData) - // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ba4a9cf08278..027583ffc60b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode @@ -141,7 +141,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { MLTestingUtils.checkCopyAndUids(rf, model) - testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction", + testTransformer[(Vector, Double, Double)](df, model, "prediction", "rawPrediction", "probability") { case Row(pred: Double, rawPred: Vector, probPred: Vector) => assert(pred === rawPred.argmax, s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") @@ -180,7 +180,6 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { ///////////////////////////////////////////////////////////////////////////// // Tests of feature importance ///////////////////////////////////////////////////////////////////////////// - test("Feature importance with toy data") { val numClasses = 2 val rf = new RandomForestClassifier() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 29a438396516..38bfa6626e3c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} @@ -26,6 +26,7 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -35,11 +36,17 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var linearRegressionData: DataFrame = _ + + private val seed = 42 override def beforeAll() { super.beforeAll() categoricalDataPointsRDD = sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) + linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF() } ///////////////////////////////////////////////////////////////////////////// @@ -88,7 +95,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) val model = dt.fit(df) - testTransformer[(Vector, Double)](df, model, "features", "variance") { + testTransformer[(Vector, Double, Double)](df, model, "features", "variance") { case Row(features: Vector, variance: Double) => val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() assert(variance === expectedVariance, @@ -101,7 +108,7 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { .setMaxBins(6) .setSeed(0) - testTransformerByGlobalCheckFunc[(Vector, Double)](varianceDF, dt.fit(varianceDF), + testTransformerByGlobalCheckFunc[(Vector, Double, Double)](varianceDF, dt.fit(varianceDF), "variance") { case rows: Seq[Row] => val calculatedVariances = rows.map(_.getDouble(0)) @@ -159,6 +166,28 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { } } + test("training with sample weights") { + val df = linearRegressionData + val numClasses = 0 + val testParams = Seq(5, 10) + for (maxDepth <- testParams) { + val estimator = new DecisionTreeRegressor() + .setMaxDepth(maxDepth) + .setMinWeightFractionPerNode(0.05) + .setSeed(123) + MLTestingUtils.testArbitrarilyScaledWeights[DecisionTreeRegressionModel, + DecisionTreeRegressor](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99)) + MLTestingUtils.testOutliersWithSmallWeights[DecisionTreeRegressionModel, + DecisionTreeRegressor](df.as[LabeledPoint], estimator, numClasses, + MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.1, 0.99), + outlierRatio = 2) + MLTestingUtils.testOversamplingVsWeighting[DecisionTreeRegressionModel, + DecisionTreeRegressor](df.as[LabeledPoint], estimator, + MLTestingUtils.modelPredictionEquals(df, _ ~= _ relTol 0.01, 1.0), seed) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 90ceb7dee38f..76532897dff2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -891,6 +891,7 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe .setStandardization(standardization) .setRegParam(regParam) .setElasticNetParam(elasticNetParam) + .setSolver(solver) MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals) MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala index 77ab3d8bb75f..63985482795b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.mllib.tree.EnsembleTestHelper import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -26,12 +27,16 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext */ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { - test("BaggedPoint RDD: without subsampling") { - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + test("BaggedPoint RDD: without subsampling with weights") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map { lp => + Instance(lp.label, 0.5, lp.features.asML) + } val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, + (instance: Instance) => instance.weight * 4.0, seed = 42) baggedRDD.collect().foreach { baggedPoint => - assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + assert(baggedPoint.subsampleCounts.size === 1 && baggedPoint.subsampleCounts(0) === 1) + assert(baggedPoint.sampleWeight === 2.0) } } @@ -40,13 +45,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val (expectedMean, expectedStddev) = (1.0, 1.0) val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, + (_: LabeledPoint) => 2.0, seed) + val subsampleCounts: Array[Array[Double]] = + baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) + // should ignore weight function for now + assert(baggedRDD.collect().forall(_.sampleWeight === 1.0)) } } @@ -59,8 +68,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + val baggedRDD = + BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed = seed) + val subsampleCounts: Array[Array[Double]] = + baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) } @@ -71,13 +82,17 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val (expectedMean, expectedStddev) = (1.0, 0) val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000).map(_.asML) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, + (_: LabeledPoint) => 2.0, seed) + val subsampleCounts: Array[Array[Double]] = + baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) + // should ignore weight function for now + assert(baggedRDD.collect().forall(_.sampleWeight === 1.0)) } } @@ -90,8 +105,10 @@ class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) val rdd = sc.parallelize(arr) seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, + seed = seed) + val subsampleCounts: Array[Array[Double]] = + baggedRDD.map(_.subsampleCounts.map(_.toDouble)).collect() EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, expectedStddev, epsilon = 0.01) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 5caa5117d575..b89cc6053b06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.ml.tree.impl import scala.annotation.tailrec import scala.collection.mutable +import scala.language.implicitConversions import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.tree._ import org.apache.spark.ml.util.TestingUtils._ @@ -46,7 +47,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { ///////////////////////////////////////////////////////////////////////////// test("Binary classification with continuous features: split calculation") { - val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML) + val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100) @@ -58,7 +59,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Binary classification with binary (ordered) categorical features: split calculation") { - val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, @@ -75,7 +76,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("Binary classification with 3-ary (ordered) categorical features," + " with no samples for one category: split calculation") { - val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, @@ -93,12 +94,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { test("find splits for a continuous feature") { // find splits for normal case { - val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 200000, 200000.0, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) - val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) + val featureSamples = Array.fill(10000)((1.0, math.random)).filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) assert(splits.length === 5) assert(fakeMetadata.numSplits(0) === 5) @@ -109,15 +110,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // SPARK-16957: Use midpoints for split values. { - val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 8, 8.0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) // possibleSplits <= numSplits { - val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble).filter(_ != 0.0) + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1) + .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2) assert(splits === expectedSplits) @@ -125,7 +127,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // possibleSplits > numSplits { - val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble).filter(_ != 0.0) + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3) + .map(x => (1.0, x.toDouble)).filter(_._2 != 0.0) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -135,12 +138,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits should not return identical splits // when there are not enough split candidates, reduce the number of splits in metadata { - val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 12, 12.0, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) - val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) assert(splits === expectedSplits) @@ -150,13 +153,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the minimum { - val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 18, 18.0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) - val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) - .map(_.toDouble) + val featureSamples = + Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) assert(splits === expectedSplits) @@ -164,37 +167,55 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // find splits when most samples close to the maximum { - val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 17, 17.0, 0, 0, Map(), Set(), Array(2), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) - val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) - .map(_.toDouble).filter(_ != 0.0) + val featureSamples = + Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(x => (1.0, x.toDouble)) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) val expectedSplits = Array((1.0 + 2.0) / 2) assert(splits === expectedSplits) } - // find splits for constant feature + // find splits for arbitrarily scaled data { - val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0, + Map(), Set(), + Array(6), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0.0, 0, 0 + ) + val featureSamplesUnitWeight = Array.fill(10)((1.0, math.random)) + val featureSamplesSmallWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 0.001, x)} + val featureSamplesLargeWeight = featureSamplesUnitWeight.map { case (w, x) => (w * 1000, x)} + val splitsUnitWeight = RandomForest + .findSplitsForContinuousFeature(featureSamplesUnitWeight, fakeMetadata, 0) + val splitsSmallWeight = RandomForest + .findSplitsForContinuousFeature(featureSamplesSmallWeight, fakeMetadata, 0) + val splitsLargeWeight = RandomForest + .findSplitsForContinuousFeature(featureSamplesLargeWeight, fakeMetadata, 0) + assert(splitsUnitWeight === splitsSmallWeight) + assert(splitsUnitWeight === splitsLargeWeight) + } + + // find splits when most weight is close to the minimum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0.0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0.0, 0, 0 ) - val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) - val featureSamplesEmpty = Array.empty[Double] + val featureSamples = Array((10, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6)).map { + case (w, x) => (w.toDouble, x.toDouble) + } val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) - assert(splits === Array.empty[Double]) - val splitsEmpty = - RandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) - assert(splitsEmpty === Array.empty[Double]) + assert(splits === Array(1.5, 2.5, 3.5, 4.5, 5.5)) } } test("train with empty arrays") { - val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])) + val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])).toInstance val data = Array.fill(5)(lp) val rdd = sc.parallelize(data) @@ -209,8 +230,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("train with constant features") { - val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) - val data = Array.fill(5)(lp) + val instance = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)).toInstance + val data = Array.fill(5)(instance) val rdd = sc.parallelize(data) val strategy = new OldStrategy( OldAlgo.Classification, @@ -222,7 +243,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None) assert(tree.rootNode.impurity === -1.0) assert(tree.depth === 0) - assert(tree.rootNode.prediction === lp.label) + assert(tree.rootNode.prediction === instance.label) // Test with no categorical features val strategy2 = new OldStrategy( @@ -233,11 +254,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None) assert(tree2.rootNode.impurity === -1.0) assert(tree2.depth === 0) - assert(tree2.rootNode.prediction === lp.label) + assert(tree2.rootNode.prediction === instance.label) } test("Multiclass classification with unordered categorical features: split calculations") { - val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy( @@ -278,7 +299,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Multiclass classification with ordered categorical features: split calculations") { - val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML) + val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + .map(_.asML.toInstance) assert(arr.length === 3000) val rdd = sc.parallelize(arr) val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, @@ -310,7 +332,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr) + val input = sc.parallelize(arr.map(_.toInstance)) val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) @@ -352,7 +374,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) - val input = sc.parallelize(arr) + val input = sc.parallelize(arr.map(_.toInstance)) val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) @@ -404,7 +426,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { LabeledPoint(0.0, Vectors.dense(2.0)), LabeledPoint(0.0, Vectors.dense(2.0)), LabeledPoint(1.0, Vectors.dense(2.0))) - val input = sc.parallelize(arr) + val input = sc.parallelize(arr.map(_.toInstance)) // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, @@ -424,7 +446,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { } test("Second level node building with vs. without groups") { - val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML) + val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML.toInstance) assert(arr.length === 1000) val rdd = sc.parallelize(arr) // For tree with 1 group @@ -468,7 +490,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) { val numFeatures = 50 val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) - val rdd = sc.parallelize(arr).map(_.asML) + val rdd = sc.parallelize(arr).map(_.asML.toInstance) // Select feature subset for top nodes. Return true if OK. def checkFeatureSubsetStrategy( @@ -581,16 +603,16 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { left2 parent left right */ - val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0)) + val leftImp = new GiniCalculator(Array(3.0, 2.0, 1.0), 6L) val left = new LeafNode(0.0, leftImp.calculate(), leftImp) - val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0)) + val rightImp = new GiniCalculator(Array(1.0, 2.0, 5.0), 8L) val right = new LeafNode(2.0, rightImp.calculate(), rightImp) val parent = TreeTests.buildParentNode(left, right, new ContinuousSplit(0, 0.5)) val parentImp = parent.impurityStats - val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0)) + val left2Imp = new GiniCalculator(Array(1.0, 6.0, 1.0), 8L) val left2 = new LeafNode(0.0, left2Imp.calculate(), left2Imp) val grandParent = TreeTests.buildParentNode(left2, parent, new ContinuousSplit(1, 1.0)) @@ -647,12 +669,12 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 // in both branches. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)), + Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)), + Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)), + Instance(1.0, 1.0, Vectors.dense(1.0, 0.0)), + Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)), + Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)) ) val rdd = sc.parallelize(arr) @@ -677,13 +699,13 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 // in both branches. val arr = Array( - LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), - LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), - LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + Instance(0.0, 1.0, Vectors.dense(0.0, 1.0)), + Instance(1.0, 1.0, Vectors.dense(0.0, 1.0)), + Instance(0.0, 1.0, Vectors.dense(0.0, 0.0)), + Instance(0.0, 1.0, Vectors.dense(1.0, 0.0)), + Instance(1.0, 1.0, Vectors.dense(1.0, 1.0)), + Instance(0.0, 1.0, Vectors.dense(1.0, 1.0)), + Instance(0.5, 1.0, Vectors.dense(1.0, 1.0)) ) val rdd = sc.parallelize(arr) @@ -700,6 +722,56 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(unprunedTree.numNodes === 5) assert(RandomForestSuite.getSumLeafCounters(List(prunedTree.rootNode)) === arr.size) } + + test("weights at arbitrary scale") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(3, 10) + val rddWithUnitWeights = sc.parallelize(arr.map(_.asML.toInstance)) + val rddWithSmallWeights = rddWithUnitWeights.map { inst => + Instance(inst.label, 0.001, inst.features) + } + val rddWithBigWeights = rddWithUnitWeights.map { inst => + Instance(inst.label, 1000, inst.features) + } + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2) + val unitWeightTrees = RandomForest.run(rddWithUnitWeights, strategy, 3, "all", 42L, None) + + val smallWeightTrees = RandomForest.run(rddWithSmallWeights, strategy, 3, "all", 42L, None) + unitWeightTrees.zip(smallWeightTrees).foreach { case (unitTree, smallWeightTree) => + TreeTests.checkEqual(unitTree, smallWeightTree) + } + + val bigWeightTrees = RandomForest.run(rddWithBigWeights, strategy, 3, "all", 42L, None) + unitWeightTrees.zip(bigWeightTrees).foreach { case (unitTree, bigWeightTree) => + TreeTests.checkEqual(unitTree, bigWeightTree) + } + } + + test("minWeightFraction and minInstancesPerNode") { + val data = Array( + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(0.0, 1.0, Vectors.dense(0.0)), + Instance(1.0, 0.1, Vectors.dense(1.0)) + ) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, + minWeightFractionPerNode = 0.5) + val Array(tree1) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree1.depth === 0) + + strategy.minWeightFractionPerNode = 0.0 + val Array(tree2) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree2.depth === 1) + + strategy.minInstancesPerNode = 2 + val Array(tree3) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree3.depth === 0) + + strategy.minInstancesPerNode = 1 + val Array(tree4) = RandomForest.run(rdd, strategy, 1, "all", 42L, None) + assert(tree4.depth === 1) + } } private object RandomForestSuite { @@ -717,7 +789,7 @@ private object RandomForestSuite { else { nodes.head match { case i: InternalNode => getSumLeafCounters(i.leftChild :: i.rightChild :: nodes.tail, acc) - case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.count) + case l: LeafNode => getSumLeafCounters(nodes.tail, acc + l.impurityStats.rawCount) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala index f41abe48f2c5..ce473ebf52e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreePointSuite.scala @@ -27,7 +27,7 @@ class TreePointSuite extends SparkFunSuite { val ser = new KryoSerializer(conf).newInstance() - val point = new TreePoint(1.0, Array(1, 2, 3)) + val point = new TreePoint(1.0, Array(1, 2, 3), 1.0) val point2 = ser.deserialize[TreePoint](ser.serialize(point)) assert(point.label === point2.label) assert(point.binnedFeatures === point2.binnedFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index ae9794b87b08..f3096e28d3d4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -18,13 +18,15 @@ package org.apache.spark.ml.tree.impl import scala.collection.JavaConverters._ +import scala.util.Random import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.api.java.JavaRDD import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} -import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} @@ -32,6 +34,7 @@ private[ml] object TreeTests extends SparkFunSuite { /** * Convert the given data to a DataFrame, and set the features and label metadata. + * * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. * @param categoricalFeatures Map: categorical feature index to number of distinct values @@ -39,16 +42,22 @@ private[ml] object TreeTests extends SparkFunSuite { * @return DataFrame with metadata */ def setMetadata( - data: RDD[LabeledPoint], + data: RDD[_], categoricalFeatures: Map[Int, Int], numClasses: Int): DataFrame = { + val dataOfInstance: RDD[Instance] = data.map { + _ match { + case instance: Instance => instance + case labeledPoint: LabeledPoint => labeledPoint.toInstance + } + } val spark = SparkSession.builder() .sparkContext(data.sparkContext) .getOrCreate() import spark.implicits._ - val df = data.toDF() - val numFeatures = data.first().features.size + val df = dataOfInstance.toDF() + val numFeatures = dataOfInstance.first().features.size val featuresAttributes = Range(0, numFeatures).map { feature => if (categoricalFeatures.contains(feature)) { NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) @@ -64,7 +73,7 @@ private[ml] object TreeTests extends SparkFunSuite { } val labelMetadata = labelAttribute.toMetadata() df.select(df("features").as("features", featuresMetadata), - df("label").as("label", labelMetadata)) + df("label").as("label", labelMetadata), df("weight")) } /** @@ -80,6 +89,7 @@ private[ml] object TreeTests extends SparkFunSuite { /** * Set label metadata (particularly the number of classes) on a DataFrame. + * * @param data Dataset. Categorical features and labels must already have 0-based indices. * This must be non-empty. * @param numClasses Number of classes label can take. If 0, mark as continuous. @@ -124,8 +134,8 @@ private[ml] object TreeTests extends SparkFunSuite { * make mistakes such as creating loops of Nodes. */ private def checkEqual(a: Node, b: Node): Unit = { - assert(a.prediction === b.prediction) - assert(a.impurity === b.impurity) + assert(a.prediction ~== b.prediction absTol 1e-8) + assert(a.impurity ~== b.impurity absTol 1e-8) (a, b) match { case (aye: InternalNode, bee: InternalNode) => assert(aye.split === bee.split) @@ -156,6 +166,7 @@ private[ml] object TreeTests extends SparkFunSuite { /** * Helper method for constructing a tree for testing. * Given left, right children, construct a parent node. + * * @param split Split for parent node * @return Parent node with children attached */ @@ -163,8 +174,8 @@ private[ml] object TreeTests extends SparkFunSuite { val leftImp = left.impurityStats val rightImp = right.impurityStats val parentImp = leftImp.copy.add(rightImp) - val leftWeight = leftImp.count / parentImp.count.toDouble - val rightWeight = rightImp.count / parentImp.count.toDouble + val leftWeight = leftImp.count / parentImp.count + val rightWeight = rightImp.count / parentImp.count val gain = parentImp.calculate() - (leftWeight * leftImp.calculate() + rightWeight * rightImp.calculate()) val pred = parentImp.predict diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 91a8b14625a8..2c73700ee962 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} +import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.recommendation.{ALS, ALSModel} import org.apache.spark.ml.tree.impl.TreeTests import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} @@ -205,8 +205,8 @@ object MLTestingUtils extends SparkFunSuite { seed: Long): Unit = { val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances( data, seed) - val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData) val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData) + val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData) modelEquals(weightedModel, overSampledModel) } @@ -228,7 +228,8 @@ object MLTestingUtils extends SparkFunSuite { List.fill(outlierRatio)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) } val trueModel = estimator.set(estimator.weightCol, "").fit(data) - val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS) + val outlierModel = estimator.set(estimator.weightCol, "weight") + .fit(outlierDS) modelEquals(trueModel, outlierModel) } @@ -241,7 +242,7 @@ object MLTestingUtils extends SparkFunSuite { estimator: E with HasWeightCol, modelEquals: (M, M) => Unit): Unit = { estimator.set(estimator.weightCol, "weight") - val models = Seq(0.001, 1.0, 1000.0).map { w => + val models = Seq(0.01, 1.0, 1000.0).map { w => val df = data.withColumn("weight", lit(w)) estimator.fit(df) } @@ -268,4 +269,20 @@ object MLTestingUtils extends SparkFunSuite { assert(newDatasetF.schema(featuresColName).dataType.equals(new ArrayType(FloatType, false))) (newDataset, newDatasetD, newDatasetF) } + + def modelPredictionEquals[M <: PredictionModel[_, M]]( + data: DataFrame, + compareFunc: (Double, Double) => Boolean, + fractionInTol: Double)( + model1: M, + model2: M): Unit = { + val pred1 = model1.transform(data).select(model1.getPredictionCol).collect() + val pred2 = model2.transform(data).select(model2.getPredictionCol).collect() + val inTol = pred1.zip(pred2).count { case (p1, p2) => + val x = p1.getDouble(0) + val y = p2.getDouble(0) + compareFunc(x, y) + } + assert(inTol / pred1.length.toDouble >= fractionInTol) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 34bc303ac607..8378a599362a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -73,7 +73,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -100,7 +100,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -116,7 +116,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -133,7 +133,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -150,7 +150,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -167,7 +167,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Entropy, maxDepth = 3, numClasses = 2, maxBins = 100) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) @@ -183,7 +183,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(strategy.isMulticlassClassification) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -240,7 +240,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = maxBins, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) @@ -288,7 +288,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) @@ -310,7 +310,7 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext { numClasses = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) - val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML), strategy) + val metadata = DecisionTreeMetadata.buildMetadata(rdd.map(_.asML.toInstance), strategy) assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala index d0f02dd966bd..078c6e6fff9f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala @@ -18,23 +18,63 @@ package org.apache.spark.mllib.tree import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.tree.impurity._ /** * Test suites for `GiniAggregator` and `EntropyAggregator`. */ class ImpuritySuite extends SparkFunSuite { + + private val seed = 42 + test("Gini impurity does not support negative labels") { val gini = new GiniAggregator(2) intercept[IllegalArgumentException] { - gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + gini.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0) } } test("Entropy does not support negative labels") { val entropy = new EntropyAggregator(2) intercept[IllegalArgumentException] { - entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0) + entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 3, 0.0) + } + } + + test("Classification impurities are insensitive to scaling") { + val rng = new scala.util.Random(seed) + val weightedCounts = Array.fill(5)(rng.nextDouble()) + val smallWeightedCounts = weightedCounts.map(_ * 0.0001) + val largeWeightedCounts = weightedCounts.map(_ * 10000) + Seq(Gini, Entropy).foreach { impurity => + val impurity1 = impurity.calculate(weightedCounts, weightedCounts.sum) + assert(impurity.calculate(smallWeightedCounts, smallWeightedCounts.sum) + ~== impurity1 relTol 0.005) + assert(impurity.calculate(largeWeightedCounts, largeWeightedCounts.sum) + ~== impurity1 relTol 0.005) } } + + test("Regression impurities are insensitive to scaling") { + def computeStats(samples: Seq[Double], weights: Seq[Double]): (Double, Double, Double) = { + samples.zip(weights).foldLeft((0.0, 0.0, 0.0)) { case ((wn, wy, wyy), (y, w)) => + (wn + w, wy + w * y, wyy + w * y * y) + } + } + val rng = new scala.util.Random(seed) + val samples = Array.fill(10)(rng.nextDouble()) + val _weights = Array.fill(10)(rng.nextDouble()) + val smallWeights = _weights.map(_ * 0.0001) + val largeWeights = _weights.map(_ * 10000) + val (count, sum, sumSquared) = computeStats(samples, _weights) + Seq(Variance).foreach { impurity => + val impurity1 = impurity.calculate(count, sum, sumSquared) + val (smallCount, smallSum, smallSumSquared) = computeStats(samples, smallWeights) + val (largeCount, largeSum, largeSumSquared) = computeStats(samples, largeWeights) + assert(impurity.calculate(smallCount, smallSum, smallSumSquared) ~== impurity1 relTol 0.005) + assert(impurity.calculate(largeCount, largeSum, largeSumSquared) ~== impurity1 relTol 0.005) + } + } + }