From bc7e824cbc7e8995dc9c04df8ece9e7a45fea168 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Mon, 6 Jun 2016 17:55:54 -0700 Subject: [PATCH 01/18] Modify impurity implementations. NOTICE: a further modification is needed (getCalculator & fromString methods) --- .../spark/mllib/tree/impurity/Entropy.scala | 5 + .../spark/mllib/tree/impurity/Gini.scala | 5 + .../spark/mllib/tree/impurity/Impurity.scala | 5 + .../spark/mllib/tree/impurity/Variance.scala | 5 + .../mllib/tree/impurity/WeightedGini.scala | 210 ++++++++++++++++++ 5 files changed, 230 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index ff7700d2d1b7..d9f089842e3b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -138,6 +138,11 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal */ def count: Long = stats.sum.toLong + /** + * Weighted count of the data points, which in this case assumes uniform class weights + */ + def weightedCount: Double = stats.sum + /** * Prediction which should be made based on the sufficient statistics. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 58dc79b7398e..ded6488ddc79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -134,6 +134,11 @@ private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcul */ def count: Long = stats.sum.toLong + /** + * Weighted count of the data points, which in this case assumes uniform class weights + */ + def weightedCount: Double = stats.sum + /** * Prediction which should be made based on the sufficient statistics. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 65f0163ec605..89cf81835a03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -147,6 +147,11 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten */ def count: Long + /** + * Weighted count of the data points accounted for in the sufficient statistics + */ + def weightedCount: Double + /** * Prediction which should be made based on the sufficient statistics. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 2423516123b8..9227cc14540a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -122,6 +122,11 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa */ def count: Long = stats(0).toLong + /** + * Weighted count of the data points, which in this case assumes uniform class weights + */ + def weightedCount: Double = stats(0) + /** * Prediction which should be made based on the sufficient statistics. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala new file mode 100644 index 000000000000..42e417278d2f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} + +/** + * :: Experimental :: + * Class for calculating the + * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] + * during classification with user-specified class weights. + */ +@Since("1.0.0") +@Experimental +object WeightedGini extends Impurity { + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param weightedCounts Array[Double] with counts for each label + * @param weightedTotalCount sum of counts for all labels + * @return information value, or 0 if weightedTotalCount = 0 + */ + @Since("1.1.0") + @DeveloperApi + override def calculate(weightedCounts: Array[Double], weightedTotalCount: Double): Double = { + if (weightedTotalCount == 0) { + return 0 + } + val numClasses = weightedCounts.length + var impurity = 1.0 + var classIndex = 0 + while (classIndex < numClasses) { + val freq = weightedCounts(classIndex) / weightedTotalCount + impurity -= freq * freq + classIndex += 1 + } + impurity + } + + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value, or 0 if count = 0 + */ + @Since("1.0.0") + @DeveloperApi + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Gini.calculate") + + /** + * Get this impurity instance. + * This is useful for passing impurity parameters to a Strategy in Java. + */ + @Since("1.1.0") + def instance: this.type = this + +} + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + * @param weights Weights of classes + */ +private[tree] class WeightedGiniAggregator(numClasses: Int, weights: Array[Double]) + extends ImpurityAggregator(numClasses) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). 
+ */ + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + if (label < 0) { + throw new IllegalArgumentException(s"WeightedGiniAggregator given label $label" + + s" but requires label to be non-negative.") + } + allStats(offset + label.toInt) += instanceWeight + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): WeightedGiniCalculator = { + new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, weights) + } +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[WeightedGiniAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + * @param weights Weights of classes + */ +private[spark] class WeightedGiniCalculator(stats: Array[Double], weights: Array[Double]) + extends ImpurityCalculator(stats) { + + var weightedStats = stats.zip(weights).map(x => x._1 * x._2) + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: WeightedGiniCalculator = new WeightedGiniCalculator(stats.clone(), weights.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = WeightedGini.calculate(weightedStats, weightedStats.sum) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats.sum.toLong + + /** + * Weighted count of the data points, computed with the per-class weights + */ + def weightedCount: Double = weightedStats.sum + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + indexOfLargestArrayElement(weightedStats) + } + + /** + * Probability of the label given by [[predict]]. + */ + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < stats.length, + s"WeightedGiniCalculator.prob given invalid label: $lbl (should be < ${stats.length})") + require(lbl >= 0, "WeightedGiniImpurity does not support negative labels") + val cnt = weightedCount + if (cnt == 0) { + 0 + } else { + weightedStats(lbl) / cnt + } + } + + override def toString: String = s"WeightedGiniCalculator(stats = [${stats.mkString(", ")}])" + + /** + * Add the stats from another calculator into this one, modifying and returning this calculator. + * Update the weightedStats at the same time + */ + override def add(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.length == other.stats.length, + s"Two ImpurityCalculator instances cannot be added with different counts sizes." + + s" Sizes are ${stats.length} and ${other.stats.length}.") + val otherCalculator = other.asInstanceOf[WeightedGiniCalculator] + var i = 0 + val len = other.stats.length + while (i < len) { + stats(i) += other.stats(i) + weightedStats(i) += otherCalculator.weightedStats(i) + i += 1 + } + this + } + + /** + * Subtract the stats from another calculator from this one, modifying and returning this + * calculator. 
Update the weightedStats at the same time + */ + override def subtract(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.length == other.stats.length, + s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + + s" Sizes are ${stats.length} and ${other.stats.length}.") + val otherCalculator = other.asInstanceOf[WeightedGiniCalculator] + var i = 0 + val len = other.stats.length + while (i < len) { + stats(i) -= other.stats(i) + weightedStats(i) -= otherCalculator.weightedStats(i) + i += 1 + } + this + } +} From df3b4e7831995aa4a53b914732152a515629a057 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Tue, 7 Jun 2016 14:42:53 -0700 Subject: [PATCH 02/18] save changes, but compile error exists --- .../DecisionTreeClassifier.scala | 3 +++ .../ml/tree/impl/DTStatsAggregator.scala | 2 +- .../ml/tree/impl/DecisionTreeMetadata.scala | 6 +++-- .../spark/ml/tree/impl/RandomForest.scala | 11 +++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 23 +++++++++++++++---- .../mllib/tree/configuration/Strategy.scala | 20 ++++++++++++---- .../mllib/tree/impurity/WeightedGini.scala | 4 ++-- 7 files changed, 53 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 881dcefb79be..a4b2c1113cdd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -82,6 +82,9 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") override def setSeed(value: Long): this.type = super.setSeed(value) + @Since("2.0.0") + override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value) + override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index 61091bb803e4..3d175006b9ab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -20,7 +20,6 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.mllib.tree.impurity._ - /** * DecisionTree statistics aggregator for a node. 
* This holds a flat array of statistics for a set of (features, bins) @@ -38,6 +37,7 @@ private[spark] class DTStatsAggregator( case Gini => new GiniAggregator(metadata.numClasses) case Entropy => new EntropyAggregator(metadata.numClasses) case Variance => new VarianceAggregator() + case WeightedGini => new WeightedGiniAggregator(metadata.numClasses, metadata.classWeights) case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 442f52bf0231..a8ad966adf1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -53,7 +53,8 @@ private[spark] class DecisionTreeMetadata( val minInstancesPerNode: Int, val minInfoGain: Double, val numTrees: Int, - val numFeaturesPerNode: Int) extends Serializable { + val numFeaturesPerNode: Int, + val classWeights: Array[Double]) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -207,7 +208,8 @@ private[spark] object DecisionTreeMetadata extends Logging { new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode, + strategy.classWeights) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 71c8c42ce5eb..fe83d602764a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -657,8 +657,15 @@ private[spark] object RandomForest extends Logging { val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble + // Weighted count is equivalent to normal count using Gini or Entropy impurity + // where the class weights are assumed to be uniform + val leftWeightedCount = leftImpurityCalculator.weightedCount + val rightWeightedCount = rightImpurityCalculator.weightedCount + + val totalWeightedCount = leftWeightedCount + rightWeightedCount + + val leftWeight = leftWeightedCount / totalWeightedCount.toDouble + val rightWeight = rightWeightedCount / totalWeightedCount.toDouble val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index d7559f8950c3..b27b2045b22a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -18,13 +18,12 @@ package org.apache.spark.ml.tree import scala.util.Try - import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => 
OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.impurity.{WeightedGini, Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -102,8 +101,16 @@ private[ml] trait DecisionTreeParams extends PredictorParams " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") + /** + * An array that stores the weights of class labels. All elements must be non-negative. + * (default = Array()) + * @group expertParam + */ + final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + + " that stores the weights of class labels. All elements must be non-negative.") + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10, classWeights -> Array()) /** @group setParam */ def setMaxDepth(value: Int): this.type = set(maxDepth, value) @@ -144,6 +151,12 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) + /** @group expertSetParam */ + def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) + + /** @group expertGetParam */ + final def getClassWeights: Array[Double] = $(classWeights) + /** * Specifies how often to checkpoint the cached node IDs. * E.g. 10 means that the cache will get checkpointed every 10 iterations. @@ -174,6 +187,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures strategy.subsamplingRate = subsamplingRate + strategy.classWeights = getClassWeights strategy } } @@ -207,6 +221,7 @@ private[ml] trait TreeClassifierParams extends Params { getImpurity match { case "entropy" => OldEntropy case "gini" => OldGini + case "weightedgini" => WeightedGini case _ => // Should never happen because of check in setter method. throw new RuntimeException( @@ -217,7 +232,7 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. 
- final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("entropy", "gini", "weightedgini").map(_.toLowerCase) } private[ml] trait DecisionTreeClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b34e1b1b56c4..f875780680f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance, WeightedGini} /** * Stores all the configuration options for tree construction @@ -32,6 +32,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] * @param impurity Criterion used for information gain calculation. * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.WeightedGini]], * [[org.apache.spark.mllib.tree.impurity.Entropy]]. * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means @@ -80,7 +81,8 @@ class Strategy @Since("1.3.0") ( @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, - @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { + @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10, + @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array()) extends Serializable { /** */ @@ -140,9 +142,9 @@ class Strategy @Since("1.3.0") ( require(numClasses >= 2, s"DecisionTree Strategy for Classification must have numClasses >= 2," + s" but numClasses = $numClasses.") - require(Set(Gini, Entropy).contains(impurity), + require(Set(Gini, Entropy, WeightedGini).contains(impurity), s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + - s" Valid settings: Gini, Entropy") + s" Valid settings: Gini, Entropy, WeightedGini") case Regression => require(impurity == Variance, s"DecisionTree Strategy given invalid impurity for Regression: $impurity." 
+ @@ -163,6 +165,14 @@ class Strategy @Since("1.3.0") ( require(subsamplingRate > 0 && subsamplingRate <= 1, s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + s"$subsamplingRate") + if (impurity == WeightedGini) { + require(numClasses == classWeights.length, + s"DecisionTree Strategy requires the number of class weights to be the same as the " + + s"number of classes, but there are $numClasses classes and ${classWeights.length} weights") + require(classWeights.forall((x: Double) => x >= 0), + s"DecisionTree Strategy requires all the class weights to be non-negative" + + s", but at least one of them is negative") + } } /** @@ -172,7 +182,7 @@ class Strategy @Since("1.3.0") ( def copy: Strategy = { new Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, - maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, classWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala index 42e417278d2f..195c644d395a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala @@ -64,7 +64,7 @@ object WeightedGini extends Impurity { @Since("1.0.0") @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = - throw new UnsupportedOperationException("Gini.calculate") + throw new UnsupportedOperationException("WeightedGini.calculate") /** * Get this impurity instance. * This is useful for passing impurity parameters to a Strategy in Java. */ @Since("1.1.0") def instance: this.type = this } /** * Class for updating views of a vector of sufficient statistics, * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. 
* @param weights Weights of classes */ -private[tree] class WeightedGiniAggregator(numClasses: Int, weights: Array[Double]) +private[spark] class WeightedGiniAggregator(numClasses: Int, weights: Array[Double]) extends ImpurityAggregator(numClasses) with Serializable { /** From 7bcabdac3d54ed9c682de5493da89f26e8a8e55a Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Tue, 7 Jun 2016 15:37:31 -0700 Subject: [PATCH 03/18] simple testSuites and run properly --- .../spark/mllib/tree/configuration/Strategy.scala | 3 ++- .../spark/mllib/tree/impurity/WeightedGini.scala | 14 +++++++------- .../DecisionTreeClassifierSuite.scala | 4 ++-- .../spark/ml/tree/impl/RandomForestSuite.scala | 8 ++++---- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index f875780680f5..03667a5dac71 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -82,7 +82,8 @@ class Strategy @Since("1.3.0") ( @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10, - @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array()) extends Serializable { + @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1, 1)) + extends Serializable { /** */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala index 195c644d395a..952d1ede3ad6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala @@ -80,9 +80,9 @@ object WeightedGini extends Impurity { * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. - * @param weights Weights of classes + * @param classWeights Weights of classes */ -private[spark] class WeightedGiniAggregator(numClasses: Int, weights: Array[Double]) +private[spark] class WeightedGiniAggregator(numClasses: Int, classWeights: Array[Double]) extends ImpurityAggregator(numClasses) with Serializable { /** @@ -108,7 +108,7 @@ private[spark] class WeightedGiniAggregator(numClasses: Int, weights: Array[Doub * @param offset Start index of stats for this (node, feature, bin). */ def getCalculator(allStats: Array[Double], offset: Int): WeightedGiniCalculator = { - new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, weights) + new WeightedGiniCalculator(allStats.view(offset, offset + statsSize).toArray, classWeights) } } @@ -117,16 +117,16 @@ private[spark] class WeightedGiniAggregator(numClasses: Int, weights: Array[Doub * Unlike [[WeightedGiniAggregator]], this class stores its own data and is for a specific * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). 
- * @param weights Weights of classes + * @param classWeights Weights of classes */ -private[spark] class WeightedGiniCalculator(stats: Array[Double], weights: Array[Double]) +private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights: Array[Double]) extends ImpurityCalculator(stats) { - var weightedStats = stats.zip(weights).map(x => x._1 * x._2) + var weightedStats = stats.zip(classWeights).map(x => x._1 * x._2) /** * Make a deep copy of this [[ImpurityCalculator]]. */ - def copy: WeightedGiniCalculator = new WeightedGiniCalculator(stats.clone(), weights.clone()) + def copy: WeightedGiniCalculator = new WeightedGiniCalculator(stats.clone(), classWeights.clone()) /** * Calculate the impurity from the stored sufficient statistics. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 089d30abb5ef..529a1d325811 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -79,7 +79,7 @@ class DecisionTreeClassifierSuite val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } - +/* test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { val dt = new DecisionTreeClassifier() .setMaxDepth(3) @@ -91,7 +91,7 @@ class DecisionTreeClassifierSuite compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } } - } + }*/ test("Multiclass classification stump with 3-ary (unordered) categorical features") { val rdd = categoricalDataPointsForMulticlassRDD diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index dcc2f305df75..b32822477224 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -93,7 +93,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(6), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, Array() ) val featureSamples = Array.fill(200000)(math.random) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -110,7 +110,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(5), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, Array() ) val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -124,7 +124,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, Array() ) val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) @@ -138,7 +138,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, Map(), Set(), Array(3), Gini, 
QuantileStrategy.Sort, - 0, 0, 0.0, 0, 0 + 0, 0, 0.0, 0, 0, Array() ) val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) From 61c48588f9bae554e23c1436936c5653c36ae217 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Tue, 7 Jun 2016 15:57:57 -0700 Subject: [PATCH 04/18] add unbalanced data test case, modify label prediction process so that the predictions are correct with Impurity=WeightedGini --- .../ProbabilisticClassifier.scala | 6 +- .../DecisionTreeClassifierSuite.scala | 69 +++++++++++++++++++ 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 59277d0f42b3..01fff628885d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -127,11 +127,7 @@ abstract class ProbabilisticClassificationModel[ numColsOutput += 1 } if ($(predictionCol).nonEmpty) { - val predUDF = if ($(rawPredictionCol).nonEmpty) { - udf(raw2prediction _).apply(col($(rawPredictionCol))) - } else if ($(probabilityCol).nonEmpty) { - udf(probability2prediction _).apply(col($(probabilityCol))) - } else { + val predUDF = { val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 089d30abb5ef..19279de01d14 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -30,6 +30,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import scala.io.Source +import org.apache.spark.mllib.evaluation.MulticlassMetrics + class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -69,6 +72,72 @@ class DecisionTreeClassifierSuite // Tests calling train() ///////////////////////////////////////////////////////////////////////////// + test("classification with class weights") { + val buf = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_train.csv").getLines + val header = buf.take(1).next() + var data = new Array[LabeledPoint](1000) + var idx : Int = 0 + for (row <- buf) { + val cols = row.split(",").map(_.trim) + // scalastyle:off println + //println(s"${cols(0)}|${cols(1)}|${cols(2)}|${cols(3)}|${cols(4)}") + //println( cols(0).getClass) + // scalastyle:on println + data(idx) = new LabeledPoint(cols(0).toDouble, + Vectors.dense(cols(1).toDouble, cols(2).toDouble)) + idx += 1 + } + // scalastyle:off println + //println(data) + // scalastyle:on println + val IrIsRDD = sc.parallelize(data) + val dt = new DecisionTreeClassifier() + .setImpurity("weightedgini") + .setMaxDepth(3) + .setMaxBins(400) + .setMinInstancesPerNode(1) + .setClassWeights(Array(1, 10000)) + val categoricalFeatures: Map[Int, Int] = Map() + val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) + val newTree = dt.fit(newData) + val predoutput = newTree.transform(newData) + // 
scalastyle:off println + predoutput.show(1000) + println(newTree.toDebugString) + // scalastyle:on println + /* + val predictionsAndLabels = predoutput.select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + val metrics = new MulticlassMetrics(predictionsAndLabels) + val confusionMatrix = metrics.confusionMatrix + // scalastyle:off println + println(confusionMatrix.toString()) + // scalastyle:on println + val buf2 = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_test.csv").getLines + val header2 = buf2.take(1).next() + var data2 = new Array[LabeledPoint](250) + idx = 0 + for (row <- buf2) { + val cols = row.split(",").map(_.trim) + data2(idx) = new LabeledPoint(cols(0).toDouble, + Vectors.dense(cols(1).toDouble, cols(2).toDouble)) + idx += 1 + } + // scalastyle:off println + //println(data) + // scalastyle:on println + val IrIsRDD2 = sc.parallelize(data2) + val newData2: DataFrame = TreeTests.setMetadata(IrIsRDD2, categoricalFeatures, 2) + val predoutput2 = newTree.transform(newData2) + val predictions = predoutput2.select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + val metrics2 = new MulticlassMetrics(predictions) + val confusionMatrix2 = metrics2.confusionMatrix + // scalastyle:off println + println(confusionMatrix2.toString()) + // scalastyle:on println */ + } + test("Binary classification stump with ordered categorical features") { val dt = new DecisionTreeClassifier() .setImpurity("gini") From aeb08563113f14f950ae9956dbe5104978a90196 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Tue, 7 Jun 2016 16:18:45 -0700 Subject: [PATCH 05/18] Make Decision Tree predictions correct without changing the base class ProbClassifier, but changing the definition of Impurity and DecTreClassifier --- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../spark/ml/classification/ProbabilisticClassifier.scala | 6 +++++- .../org/apache/spark/mllib/tree/impurity/Impurity.scala | 1 + .../org/apache/spark/mllib/tree/impurity/WeightedGini.scala | 2 +- .../ml/classification/DecisionTreeClassifierSuite.scala | 2 +- 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index a4b2c1113cdd..cee45e9262ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -171,7 +171,7 @@ class DecisionTreeClassificationModel private[ml] ( } override protected def predictRaw(features: Vector): Vector = { - Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone()) + Vectors.dense(rootNode.predictImpl(features).impurityStats.weightedStats.clone()) } override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 01fff628885d..59277d0f42b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -127,7 +127,11 @@ abstract class ProbabilisticClassificationModel[ numColsOutput += 1 } if ($(predictionCol).nonEmpty) { - val predUDF = { + val predUDF = if 
($(rawPredictionCol).nonEmpty) { + udf(raw2prediction _).apply(col($(rawPredictionCol))) + } else if ($(probabilityCol).nonEmpty) { + udf(probability2prediction _).apply(col($(probabilityCol))) + } else { val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 89cf81835a03..255fbd4d6de0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -99,6 +99,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser */ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { + val weightedStats: Array[Double] = stats /** * Make a deep copy of this [[ImpurityCalculator]]. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala index 952d1ede3ad6..293ac71dd139 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala @@ -122,7 +122,7 @@ private[spark] class WeightedGiniAggregator(numClasses: Int, classWeights: Array private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights: Array[Double]) extends ImpurityCalculator(stats) { - var weightedStats = stats.zip(classWeights).map(x => x._1 * x._2) + override val weightedStats = stats.zip(classWeights).map(x => x._1 * x._2) /** * Make a deep copy of this [[ImpurityCalculator]]. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 19279de01d14..7a946bcf7f5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -96,7 +96,7 @@ class DecisionTreeClassifierSuite .setMaxDepth(3) .setMaxBins(400) .setMinInstancesPerNode(1) - .setClassWeights(Array(1, 10000)) + .setClassWeights(Array(1, 1)) val categoricalFeatures: Map[Int, Int] = Map() val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) val newTree = dt.fit(newData) From 4dc3e325caf96e7aae1fec7d0f5db290d0bd7195 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Thu, 9 Jun 2016 18:36:19 -0700 Subject: [PATCH 06/18] 1. put classWeights definition in the right place 2. change interfaces of getOldStrategy 3. allow classWeights to be passed when reconstructing the tree, including JSON read/write --- .../DecisionTreeClassifier.scala | 2 +- .../RandomForestClassifier.scala | 3 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 3 +- .../org/apache/spark/ml/tree/treeModels.scala | 22 +++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 46 ++++++++++--------- .../spark/mllib/tree/impurity/Impurity.scala | 4 +- .../DecisionTreeClassifierSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- .../RandomForestRegressorSuite.scala | 3 +- 10 files changed, 56 insertions(+), 33 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index cee45e9262ca..53aaeb2fbf3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -122,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") ( categoricalFeatures: Map[Int, Int], numClasses: Int): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, - subsamplingRate = 1.0) + subsamplingRate = 1.0, getClassWeights) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index b3c074f83925..d6a1b54397a8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -104,7 +104,8 @@ class RandomForestClassifier @Since("1.4.0") ( val numClasses: Int = getNumClasses(dataset) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = - super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) + super.getOldStrategy(categoricalFeatures, numClasses, + OldAlgo.Classification, getOldImpurity, getClassWeights) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index c4df9d11127f..b2fe5ded6179 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** (private[ml]) Create a Strategy instance to use with the old API. 
*/ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, - subsamplingRate = 1.0) + subsamplingRate = 1.0, classWeights = Array()) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index a6dbf21d55e2..1fd42303b60c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -98,7 +98,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = - super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) + super.getOldStrategy(categoricalFeatures, numClasses = 0, + OldAlgo.Regression, getOldImpurity, Array()) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 56c85c9b53e1..9451e0ade56b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -342,9 +342,15 @@ private[ml] object DecisionTreeModelReadWrite { Param.jsonDecode[String](compact(render(impurityJson))) } + val classWeights: Array[Double] = { + val classWeightsJson: JValue = metadata.getParamValue("classWeights") + compact(render(classWeightsJson)).split("\\[|,|\\]") + .filter((s: String) => s.length() != 0).map((s: String) => s.toDouble) + } + val dataPath = new Path(path, "data").toString val data = sqlContext.read.parquet(dataPath).as[NodeData] - buildTreeFromNodes(data.collect(), impurityType) + buildTreeFromNodes(data.collect(), impurityType, classWeights) } /** @@ -353,7 +359,8 @@ private[ml] object DecisionTreeModelReadWrite { * @param impurityType Impurity type for this tree * @return Root node of reconstructed tree */ - def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { + def buildTreeFromNodes(data: Array[NodeData], impurityType: String, + classWeights: Array[Double]): Node = { // Load all nodes, sorted by ID. val nodes = data.sortBy(_.id) // Sanity checks; could remove @@ -365,7 +372,8 @@ private[ml] object DecisionTreeModelReadWrite { // traversal, this guarantees that child nodes will be built before parent nodes. 
val finalNodes = new Array[Node](nodes.length) nodes.reverseIterator.foreach { case n: NodeData => - val impurityStats = ImpurityCalculator.getCalculator(impurityType, n.impurityStats) + val impurityStats = ImpurityCalculator.getCalculator(impurityType, + n.impurityStats, classWeights) val node = if (n.leftChild != -1) { val leftChild = finalNodes(n.leftChild) val rightChild = finalNodes(n.rightChild) @@ -437,6 +445,12 @@ private[ml] object EnsembleModelReadWrite { Param.jsonDecode[String](compact(render(impurityJson))) } + val classWeights: Array[Double] = { + val classWeightsJson: JValue = metadata.getParamValue("classWeights") + val classWeightsVector = Param.jsonDecode[Vector](compact(render(classWeightsJson))) + classWeightsVector.toArray + } + val treesMetadataPath = new Path(path, "treesMetadata").toString val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { @@ -454,7 +468,7 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType, classWeights) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index b27b2045b22a..f60621d9bce1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -101,16 +101,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") - /** - * An array that stores the weights of class labels. All elements must be non-negative. - * (default = Array()) - * @group expertParam - */ - final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + - " that stores the weights of class labels. All elements must be non-negative.") - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10, classWeights -> Array()) + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) /** @group setParam */ def setMaxDepth(value: Int): this.type = set(maxDepth, value) @@ -151,12 +143,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** @group expertSetParam */ - def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) - - /** @group expertGetParam */ - final def getClassWeights: Array[Double] = $(classWeights) - /** * Specifies how often to checkpoint the cached node IDs. * E.g. 10 means that the cache will get checkpointed every 10 iterations. 
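As a point of reference, a minimal usage sketch of the classWeights parameter these treeParams hunks relocate (illustrative only, not part of the patch; it assumes a training DataFrame named train with the usual "label" and "features" columns, and the weight values are made up):

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // classWeights pairs with the "weightedgini" impurity added earlier in this
    // series: one non-negative weight per class, indexed by label value.
    val dt = new DecisionTreeClassifier()
      .setImpurity("weightedgini")
      .setClassWeights(Array(1.0, 10.0)) // assumed weights: up-weight label 1.0
    val model = dt.fit(train)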
@@ -174,7 +160,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams numClasses: Int, oldAlgo: OldAlgo.Algo, oldImpurity: OldImpurity, - subsamplingRate: Double): OldStrategy = { + subsamplingRate: Double, + classWeights: Array[Double]): OldStrategy = { val strategy = OldStrategy.defaultStrategy(oldAlgo) strategy.impurity = oldImpurity strategy.checkpointInterval = getCheckpointInterval @@ -187,7 +174,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures strategy.subsamplingRate = subsamplingRate - strategy.classWeights = getClassWeights + strategy.classWeights = classWeights strategy } } @@ -197,6 +184,14 @@ private[ml] trait DecisionTreeParams extends PredictorParams */ private[ml] trait TreeClassifierParams extends Params { + /** + * An array that stores the weights of class labels. All elements must be non-negative. + * (default = Array(1, 1)) + * @group expertParam + */ + final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + + " that stores the weights of class labels. All elements must be non-negative.") + /** * Criterion used for information gain calculation (case-insensitive). * Supported: "entropy" and "gini". @@ -208,7 +203,13 @@ private[ml] trait TreeClassifierParams extends Params { s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) - setDefault(impurity -> "gini") + setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0)) + + /** @group expertSetParam */ + def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) + + /** @group expertGetParam */ + final def getClassWeights: Array[Double] = $(classWeights) /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -327,8 +328,10 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { categoricalFeatures: Map[Int, Int], numClasses: Int, oldAlgo: OldAlgo.Algo, - oldImpurity: OldImpurity): OldStrategy = { - super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) + oldImpurity: OldImpurity, + classWeights: Array[Double]): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, + oldImpurity, getSubsamplingRate, classWeights) } } @@ -470,7 +473,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS private[ml] def getOldBoostingStrategy( categoricalFeatures: Map[Int, Int], oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, + oldAlgo, OldVariance, classWeights = Array(1.0, 1.0)) // NOTE: The old API does not support "seed" so we ignore it. new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 255fbd4d6de0..b91752f0ff2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -191,11 +191,13 @@ private[spark] object ImpurityCalculator { * Create an [[ImpurityCalculator]] instance of the given impurity type and with * the given stats. 
*/ - def getCalculator(impurity: String, stats: Array[Double]): ImpurityCalculator = { + def getCalculator(impurity: String, stats: Array[Double], + classWeights: Array[Double]): ImpurityCalculator = { impurity match { case "gini" => new GiniCalculator(stats) case "entropy" => new EntropyCalculator(stats) case "variance" => new VarianceCalculator(stats) + case "weightedgini" => new WeightedGiniCalculator(stats, classWeights) case _ => throw new IllegalArgumentException( s"ImpurityCalculator builder did not recognize impurity type: $impurity") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 7a946bcf7f5b..afceb1cb5a0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -96,7 +96,7 @@ class DecisionTreeClassifierSuite .setMaxDepth(3) .setMaxBins(400) .setMinInstancesPerNode(1) - .setClassWeights(Array(1, 1)) + .setClassWeights(Array(1, 1000)) val categoricalFeatures: Map[Int, Int] = Map() val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) val newTree = dt.fit(newData) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 2e99ee157ae9..a00ee857c10a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -234,7 +234,7 @@ private object RandomForestClassifierSuite extends SparkFunSuite { numClasses: Int): Unit = { val numFeatures = data.first().features.size val oldStrategy = - rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity) + rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity, rf.getClassWeights) val oldModel = OldRandomForest.trainClassifier( data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c08335f9f84a..2cdcfb9c6b08 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -140,7 +140,8 @@ private object RandomForestRegressorSuite extends SparkFunSuite { categoricalFeatures: Map[Int, Int]): Unit = { val numFeatures = data.first().features.size val oldStrategy = - rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity) + rf.getOldStrategy(categoricalFeatures, numClasses = 0, + OldAlgo.Regression, rf.getOldImpurity, classWeights = Array()) val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) From ad55f6df874aed3eb74ccacb896dd846a9cc9544 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Tue, 14 Jun 2016 17:07:19 -0700 Subject: [PATCH 07/18] add SetClassWeights to RandomForestClassifier --- 
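Note: a minimal sketch of the forest-side usage this patch enables (illustrative; the weight values and tree count below are assumptions, not taken from the series):

    import org.apache.spark.ml.classification.RandomForestClassifier

    val rf = new RandomForestClassifier()
      .setNumTrees(50)
      .setImpurity("weightedgini")
      .setClassWeights(Array(1.0, 25.0)) // one non-negative weight per class label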
.../ml/classification/RandomForestClassifier.scala | 3 +++ .../classification/DecisionTreeClassifierSuite.scala | 11 ++++++----- .../classification/RandomForestClassifierSuite.scala | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d6a1b54397a8..e14a9a7ea80d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -98,6 +98,9 @@ class RandomForestClassifier @Since("1.4.0") ( override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) + @Since("2.0.0") + override def setClassWeights(value: Array[Double]): this.type = super.setClassWeights(value) + override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index afceb1cb5a0d..6b2f76bb7689 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{DataFrame, Row} import scala.io.Source import org.apache.spark.mllib.evaluation.MulticlassMetrics + class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -96,14 +97,14 @@ class DecisionTreeClassifierSuite .setMaxDepth(3) .setMaxBins(400) .setMinInstancesPerNode(1) - .setClassWeights(Array(1, 1000)) + .setClassWeights(Array(1, 1)) val categoricalFeatures: Map[Int, Int] = Map() val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) val newTree = dt.fit(newData) val predoutput = newTree.transform(newData) // scalastyle:off println - predoutput.show(1000) - println(newTree.toDebugString) + //predoutput.show(1000) + //println(newTree.toDebugString) // scalastyle:on println /* val predictionsAndLabels = predoutput.select("prediction", "label") @@ -148,7 +149,7 @@ class DecisionTreeClassifierSuite val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } -/* + test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { val dt = new DecisionTreeClassifier() .setMaxDepth(3) @@ -160,7 +161,7 @@ class DecisionTreeClassifierSuite compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } } - }*/ + } test("Multiclass classification stump with 3-ary (unordered) categorical features") { val rdd = categoricalDataPointsForMulticlassRDD diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index a00ee857c10a..ee297223382c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -210,7 +210,7 @@ class RandomForestClassifierSuite assert(model.numClasses === model2.numClasses) } - val rf = new 
RandomForestClassifier().setNumTrees(2) + val rf = new RandomForestClassifier().setNumTrees(2).setClassWeights(Array()) val rdd = TreeTests.getTreeReadWriteData(sc) val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy") From c17067fca7449c8e5b2b6326ef3b56087f737c6e Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Wed, 15 Jun 2016 12:16:32 -0700 Subject: [PATCH 08/18] change code style such that requirements are met --- .../org/apache/spark/ml/tree/treeModels.scala | 8 +-- .../org/apache/spark/ml/tree/treeParams.scala | 16 +++--- .../spark/mllib/tree/impurity/Entropy.scala | 4 +- .../spark/mllib/tree/impurity/Variance.scala | 4 +- .../DecisionTreeClassifierSuite.scala | 53 +++++-------------- .../RandomForestClassifierSuite.scala | 3 +- 6 files changed, 32 insertions(+), 56 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 9451e0ade56b..2299d9744f1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -447,8 +447,9 @@ private[ml] object EnsembleModelReadWrite { val classWeights: Array[Double] = { val classWeightsJson: JValue = metadata.getParamValue("classWeights") - val classWeightsVector = Param.jsonDecode[Vector](compact(render(classWeightsJson))) - classWeightsVector.toArray + val classWeightsArray = compact(render(classWeightsJson)).split("\\[|,|\\]") + .filter((s: String) => s.length() != 0).map((s: String) => s.toDouble) + classWeightsArray } val treesMetadataPath = new Path(path, "treesMetadata").toString @@ -468,7 +469,8 @@ private[ml] object EnsembleModelReadWrite { val rootNodesRDD: RDD[(Int, Node)] = nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { case (treeID: Int, nodeData: Iterable[NodeData]) => - treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType, classWeights) + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, + impurityType, classWeights) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() (metadata, treesMetadata.zip(rootNodes), treesWeights) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index f60621d9bce1..5de2e89a018c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -18,12 +18,13 @@ package org.apache.spark.ml.tree import scala.util.Try + import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impurity.{WeightedGini, Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance, WeightedGini} import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -185,16 +186,16 @@ private[ml] trait DecisionTreeParams extends PredictorParams private[ml] trait TreeClassifierParams 
extends Params { /** - * An array that stores the weights of class labels. All elements must be non-negative. - * (default = Array(1, 1)) - * @group expertParam - */ + * An array that stores the weights of class labels. All elements must be non-negative. + * (default = Array(1, 1)) + * @group expertParam + */ final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + " that stores the weights of class labels. All elements must be non-negative.") /** * Criterion used for information gain calculation (case-insensitive). - * Supported: "entropy" and "gini". + * Supported: "entropy", "gini" and "weightedgini". * (default = gini) * @group param */ @@ -233,7 +234,8 @@ private[ml] trait TreeClassifierParams extends Params { private[ml] object TreeClassifierParams { // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("entropy", "gini", "weightedgini").map(_.toLowerCase) + final val supportedImpurities: Array[String] = Array("entropy", "gini", "weightedgini") + .map(_.toLowerCase) } private[ml] trait DecisionTreeClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index d9f089842e3b..de24ba844451 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -139,8 +139,8 @@ private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCal def count: Long = stats.sum.toLong /** - * Weighted summary statistics of data points, which in this case assume uniform class weights - */ + * Weighted summary statistics of data points, which in this case assume uniform class weights + */ def weightedCount: Double = stats.sum /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 9227cc14540a..1087139fb4bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -123,8 +123,8 @@ private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCa def count: Long = stats(0).toLong /** - * Weighted summary statistics of data points, which in this case assume uniform class weights - */ + * Weighted summary statistics of data points, which in this case assume uniform class weights + */ def weightedCount: Double = stats(0) /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 6b2f76bb7689..3592ce7f23e7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import scala.io.Source + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} @@ -30,8 +32,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} -import scala.io.Source -import org.apache.spark.mllib.evaluation.MulticlassMetrics + class DecisionTreeClassifierSuite @@ -74,22 +75,23 @@ class 
DecisionTreeClassifierSuite ///////////////////////////////////////////////////////////////////////////// test("classification with class weights") { - val buf = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_train.csv").getLines + val buf = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_train.csv") + .getLines val header = buf.take(1).next() var data = new Array[LabeledPoint](1000) var idx : Int = 0 for (row <- buf) { val cols = row.split(",").map(_.trim) - // scalastyle:off println - //println(s"${cols(0)}|${cols(1)}|${cols(2)}|${cols(3)}|${cols(4)}") - //println( cols(0).getClass) - // scalastyle:on println + // scala style: off println + // println(s"${cols(0)}|${cols(1)}|${cols(2)}|${cols(3)}|${cols(4)}") + // println( cols(0).getClass) + // scala style: on println data(idx) = new LabeledPoint(cols(0).toDouble, Vectors.dense(cols(1).toDouble, cols(2).toDouble)) idx += 1 } // scalastyle:off println - //println(data) + // println(data) // scalastyle:on println val IrIsRDD = sc.parallelize(data) val dt = new DecisionTreeClassifier() @@ -103,40 +105,9 @@ class DecisionTreeClassifierSuite val newTree = dt.fit(newData) val predoutput = newTree.transform(newData) // scalastyle:off println - //predoutput.show(1000) - //println(newTree.toDebugString) + // predoutput.show(1000) + // println(newTree.toDebugString) // scalastyle:on println - /* - val predictionsAndLabels = predoutput.select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) - val metrics = new MulticlassMetrics(predictionsAndLabels) - val confusionMatrix = metrics.confusionMatrix - // scalastyle:off println - println(confusionMatrix.toString()) - // scalastyle:on println - val buf2 = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_test.csv").getLines - val header2 = buf2.take(1).next() - var data2 = new Array[LabeledPoint](250) - idx = 0 - for (row <- buf2) { - val cols = row.split(",").map(_.trim) - data2(idx) = new LabeledPoint(cols(0).toDouble, - Vectors.dense(cols(1).toDouble, cols(2).toDouble)) - idx += 1 - } - // scalastyle:off println - //println(data) - // scalastyle:on println - val IrIsRDD2 = sc.parallelize(data2) - val newData2: DataFrame = TreeTests.setMetadata(IrIsRDD2, categoricalFeatures, 2) - val predoutput2 = newTree.transform(newData2) - val predictions = predoutput2.select("prediction", "label") - .map(row => (row.getDouble(0), row.getDouble(1))) - val metrics2 = new MulticlassMetrics(predictions) - val confusionMatrix2 = metrics2.confusionMatrix - // scalastyle:off println - println(confusionMatrix2.toString()) - // scalastyle:on println */ } test("Binary classification stump with ordered categorical features") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ee297223382c..eabc773aeeb3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -234,7 +234,8 @@ private object RandomForestClassifierSuite extends SparkFunSuite { numClasses: Int): Unit = { val numFeatures = data.first().features.size val oldStrategy = - rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity, rf.getClassWeights) + rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, + rf.getOldImpurity, 
rf.getClassWeights) val oldModel = OldRandomForest.trainClassifier( data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) From 17635d99574c2fbb7bd684899c465c0289edb5ad Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Wed, 15 Jun 2016 16:57:02 -0700 Subject: [PATCH 09/18] random forest with class weights runs properly --- .../RandomForestClassifier.scala | 3 +- .../DecisionTreeClassifierSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 40 +++++++++++++++++++ 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index e14a9a7ea80d..e68b70807181 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -199,7 +199,8 @@ class RandomForestClassificationModel private[ml] ( // Ignore the tree weights since all are 1.0 for now. val votes = Array.fill[Double](numClasses)(0.0) _trees.view.foreach { tree => - val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats + val classCounts: Array[Double] = + tree.rootNode.predictImpl(features).impurityStats.weightedStats val total = classCounts.sum if (total != 0) { var i = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 3592ce7f23e7..84663724afb8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -99,13 +99,13 @@ class DecisionTreeClassifierSuite .setMaxDepth(3) .setMaxBins(400) .setMinInstancesPerNode(1) - .setClassWeights(Array(1, 1)) + .setClassWeights(Array(1, 1000)) val categoricalFeatures: Map[Int, Int] = Map() val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) val newTree = dt.fit(newData) val predoutput = newTree.transform(newData) // scalastyle:off println - // predoutput.show(1000) + predoutput.show(1000) // println(newTree.toDebugString) // scalastyle:on println } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index eabc773aeeb3..0b851cbc20ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.classification +import scala.io.Source + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} @@ -32,6 +34,7 @@ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} + /** * Test suite for [[RandomForestClassifier]]. 
*/ @@ -76,6 +79,43 @@ class RandomForestClassifierSuite ParamsSuite.checkParams(model) } + test("classification with class weights") { + val buf = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_train.csv") + .getLines + val header = buf.take(1).next() + var data = new Array[LabeledPoint](1000) + var idx : Int = 0 + for (row <- buf) { + val cols = row.split(",").map(_.trim) + // scala style: off println + // println(s"${cols(0)}|${cols(1)}|${cols(2)}|${cols(3)}|${cols(4)}") + // println( cols(0).getClass) + // scala style: on println + data(idx) = new LabeledPoint(cols(0).toDouble, + Vectors.dense(cols(1).toDouble, cols(2).toDouble)) + idx += 1 + } + // scalastyle:off println + // println(data) + // scalastyle:on println + val IrIsRDD = sc.parallelize(data) + val rf = new RandomForestClassifier() + .setImpurity("weightedgini") + .setMaxDepth(3) + .setMaxBins(400) + .setMinInstancesPerNode(1) + .setClassWeights(Array(1, 1000)) + .setSubsamplingRate(1) + val categoricalFeatures: Map[Int, Int] = Map() + val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) + val newTree = rf.fit(newData) + val predoutput = newTree.transform(newData) + // scalastyle:off println + predoutput.show(1000) + //println(newTree.toDebugString) + // scalastyle:on println + } + test("Binary classification with continuous features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() From 9d52c1f4973e3ef8770cbe0d61b7b1ae67041f68 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Thu, 16 Jun 2016 13:58:49 -0700 Subject: [PATCH 10/18] minor changes --- .../DecisionTreeClassifier.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 44 ++++--------------- .../RandomForestClassifierSuite.scala | 37 ---------------- 4 files changed, 11 insertions(+), 74 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 53aaeb2fbf3a..59aaa1cd457a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -132,7 +132,7 @@ class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.4.0") @Experimental object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifier] { - /** Accessor for supported impurities: entropy, gini */ + /** Accessor for supported impurities: entropy, gini, weightedgini */ @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 1fd42303b60c..5b2b5a2d7a20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -99,7 +99,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, - OldAlgo.Regression, getOldImpurity, Array()) + OldAlgo.Regression, getOldImpurity, classWeights = Array(1, 1)) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff 
--git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 84663724afb8..c26f8737bdf7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.classification -import scala.io.Source - import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} @@ -74,40 +72,16 @@ class DecisionTreeClassifierSuite // Tests calling train() ///////////////////////////////////////////////////////////////////////////// - test("classification with class weights") { - val buf = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_train.csv") - .getLines - val header = buf.take(1).next() - var data = new Array[LabeledPoint](1000) - var idx : Int = 0 - for (row <- buf) { - val cols = row.split(",").map(_.trim) - // scala style: off println - // println(s"${cols(0)}|${cols(1)}|${cols(2)}|${cols(3)}|${cols(4)}") - // println( cols(0).getClass) - // scala style: on println - data(idx) = new LabeledPoint(cols(0).toDouble, - Vectors.dense(cols(1).toDouble, cols(2).toDouble)) - idx += 1 - } - // scalastyle:off println - // println(data) - // scalastyle:on println - val IrIsRDD = sc.parallelize(data) + test("Binary classification with setting explicit uniform class weights") { val dt = new DecisionTreeClassifier() - .setImpurity("weightedgini") - .setMaxDepth(3) - .setMaxBins(400) - .setMinInstancesPerNode(1) - .setClassWeights(Array(1, 1000)) - val categoricalFeatures: Map[Int, Int] = Map() - val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) - val newTree = dt.fit(newData) - val predoutput = newTree.transform(newData) - // scalastyle:off println - predoutput.show(1000) - // println(newTree.toDebugString) - // scalastyle:on println + .setImpurity("WeightedGini") + .setMaxDepth(2) + .setMaxBins(100) + .setSeed(1) + .setClassWeights(Array(1, 1)) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) } test("Binary classification stump with ordered categorical features") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 0b851cbc20ce..3f25adc64b9e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -79,43 +79,6 @@ class RandomForestClassifierSuite ParamsSuite.checkParams(model) } - test("classification with class weights") { - val buf = Source.fromFile("/Users/yueweina/Documents/spark/data/mllib/unbalenced_train.csv") - .getLines - val header = buf.take(1).next() - var data = new Array[LabeledPoint](1000) - var idx : Int = 0 - for (row <- buf) { - val cols = row.split(",").map(_.trim) - // scala style: off println - // println(s"${cols(0)}|${cols(1)}|${cols(2)}|${cols(3)}|${cols(4)}") - // println( cols(0).getClass) - // scala style: on println - data(idx) = new LabeledPoint(cols(0).toDouble, - Vectors.dense(cols(1).toDouble, cols(2).toDouble)) - idx += 1 - } - // scalastyle:off println - // 
println(data) - // scalastyle:on println - val IrIsRDD = sc.parallelize(data) - val rf = new RandomForestClassifier() - .setImpurity("weightedgini") - .setMaxDepth(3) - .setMaxBins(400) - .setMinInstancesPerNode(1) - .setClassWeights(Array(1, 1000)) - .setSubsamplingRate(1) - val categoricalFeatures: Map[Int, Int] = Map() - val newData: DataFrame = TreeTests.setMetadata(IrIsRDD, categoricalFeatures, 2) - val newTree = rf.fit(newData) - val predoutput = newTree.transform(newData) - // scalastyle:off println - predoutput.show(1000) - //println(newTree.toDebugString) - // scalastyle:on println - } - test("Binary classification with continuous features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val rf = new RandomForestClassifier() From bf1acfdb7293ee63b31438972c25de683c852e7d Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Thu, 16 Jun 2016 17:27:41 -0700 Subject: [PATCH 11/18] add Strategy new param doc --- .../org/apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 5b2b5a2d7a20..1fd42303b60c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -99,7 +99,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, - OldAlgo.Regression, getOldImpurity, classWeights = Array(1, 1)) + OldAlgo.Regression, getOldImpurity, Array()) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 03667a5dac71..204075ec98e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -66,6 +66,8 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance, * E.g. 10 means that the cache will get checkpointed every 10 updates. If * the checkpoint directory is not set in * [[org.apache.spark.SparkContext]], this setting is ignored. + * @param classWeights Weights of classes used in classification problems. It will be ignored in + * regression problems. */ @Since("1.0.0") class Strategy @Since("1.3.0") ( From 2baf814cfec0cac97390e7b64f9064e0bc378d7a Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Fri, 17 Jun 2016 11:00:49 -0700 Subject: [PATCH 12/18] move classW def to DeTrParam class --- .../org/apache/spark/ml/tree/treeParams.scala | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 5de2e89a018c..80533c32274d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -102,8 +102,17 @@ private[ml] trait DecisionTreeParams extends PredictorParams " algorithm will cache node IDs for each instance. 
Caching can speed up training of deeper" + " trees.") + /** + * An array that stores the weights of class labels. All elements must be non-negative. + * (default = Array(1, 1)) + * @group Param + */ + final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + + " that stores the weights of class labels. All elements must be non-negative.") + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10, + classWeights -> Array(1.0, 1.0)) /** @group setParam */ def setMaxDepth(value: Int): this.type = set(maxDepth, value) @@ -144,6 +153,12 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) + /** @group SetParam */ + def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) + + /** @group GetParam */ + final def getClassWeights: Array[Double] = $(classWeights) + /** * Specifies how often to checkpoint the cached node IDs. * E.g. 10 means that the cache will get checkpointed every 10 iterations. @@ -185,14 +200,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams */ private[ml] trait TreeClassifierParams extends Params { - /** - * An array that stores the weights of class labels. All elements must be non-negative. - * (default = Array(1, 1)) - * @group expertParam - */ - final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + - " that stores the weights of class labels. All elements must be non-negative.") - /** * Criterion used for information gain calculation (case-insensitive). * Supported: "entropy", "gini" and "weightedgini". 
@@ -204,13 +211,7 @@ private[ml] trait TreeClassifierParams extends Params { s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) - setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0)) - - /** @group expertSetParam */ - def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) - - /** @group expertGetParam */ - final def getClassWeights: Array[Double] = $(classWeights) + setDefault(impurity -> "gini") /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -331,7 +332,7 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { numClasses: Int, oldAlgo: OldAlgo.Algo, oldImpurity: OldImpurity, - classWeights: Array[Double]): OldStrategy = { + classWeights: Array[Double] = Array()): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate, classWeights) } From fe3819c3434d41c20e4625d81ca3da8977bdc67e Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Fri, 17 Jun 2016 11:26:40 -0700 Subject: [PATCH 13/18] remove @BeanProperty --- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 204075ec98e3..d188186546f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -84,7 +84,7 @@ class Strategy @Since("1.3.0") ( @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10, - @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1, 1)) + @Since("2.0.0") var classWeights: Array[Int] = Array(1, 1)) extends Serializable { /** From fd2eee567deb3308e6184f4ecb20f681f9fa9353 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Fri, 17 Jun 2016 12:10:42 -0700 Subject: [PATCH 14/18] change getOldImpu interfaces --- .../scala/org/apache/spark/ml/tree/treeParams.scala | 13 ++++++------- .../spark/mllib/tree/configuration/Strategy.scala | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 80533c32274d..87dea5f8d47e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -176,8 +176,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams numClasses: Int, oldAlgo: OldAlgo.Algo, oldImpurity: OldImpurity, - subsamplingRate: Double, - classWeights: Array[Double]): OldStrategy = { + subsamplingRate: Double): OldStrategy = { val strategy = OldStrategy.defaultStrategy(oldAlgo) strategy.impurity = oldImpurity strategy.checkpointInterval = getCheckpointInterval @@ -190,7 +189,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures strategy.subsamplingRate = subsamplingRate - strategy.classWeights = classWeights + strategy.classWeights = getClassWeights strategy } } @@ -331,10 +330,9 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { 
categoricalFeatures: Map[Int, Int], numClasses: Int, oldAlgo: OldAlgo.Algo, - oldImpurity: OldImpurity, - classWeights: Array[Double] = Array()): OldStrategy = { + oldImpurity: OldImpurity): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, - oldImpurity, getSubsamplingRate, classWeights) + oldImpurity, getSubsamplingRate) } } @@ -477,7 +475,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS categoricalFeatures: Map[Int, Int], oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, - oldAlgo, OldVariance, classWeights = Array(1.0, 1.0)) + oldAlgo, OldVariance) + // NOTE: The old API does not support "seed" so we ignore it. new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index d188186546f7..e75f617cd5a6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -84,7 +84,7 @@ class Strategy @Since("1.3.0") ( @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10, - @Since("2.0.0") var classWeights: Array[Int] = Array(1, 1)) + @Since("2.0.0") var classWeights: Array[Double] = Array(1, 1)) extends Serializable { /** From 455c47e274e1dff50268a6c07b2f0a67c32ac24c Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Mon, 20 Jun 2016 10:22:19 -0700 Subject: [PATCH 15/18] first version that passes all run_test, by adding a redundant constructor and reverting getOldStrategy to the old versions --- .../DecisionTreeClassifier.scala | 2 +- .../RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../mllib/tree/configuration/Strategy.scala | 25 ++++++++++++++++++- .../RandomForestClassifierSuite.scala | 2 +- .../RandomForestRegressorSuite.scala | 2 +- scalastyle-config.xml | 2 +- 8 files changed, 31 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 59aaa1cd457a..52093b358e2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -122,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") ( categoricalFeatures: Map[Int, Int], numClasses: Int): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, - subsamplingRate = 1.0, getClassWeights) + subsamplingRate = 1.0) } @Since("1.4.1") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index e68b70807181..16afb9e57a12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -108,7 +108,7 @@ class RandomForestClassifier @Since("1.4.0") ( val oldDataset: RDD[LabeledPoint] =
extractLabeledPoints(dataset, numClasses) val strategy = super.getOldStrategy(categoricalFeatures, numClasses, - OldAlgo.Classification, getOldImpurity, getClassWeights) + OldAlgo.Classification, getOldImpurity) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index b2fe5ded6179..c4df9d11127f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, - subsamplingRate = 1.0, classWeights = Array()) + subsamplingRate = 1.0) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 1fd42303b60c..ba29be16561f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -99,7 +99,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, - OldAlgo.Regression, getOldImpurity, Array()) + OldAlgo.Regression, getOldImpurity) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index e75f617cd5a6..30cac57451e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -84,7 +84,7 @@ class Strategy @Since("1.3.0") ( @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10, - @Since("2.0.0") var classWeights: Array[Double] = Array(1, 1)) + @Since("2.0.0") @BeanProperty var classWeights: Array[Double] = Array(1.0, 1.0)) extends Serializable { /** @@ -101,6 +101,29 @@ class Strategy @Since("1.3.0") ( isMulticlassClassification && (categoricalFeaturesInfo.size > 0) } + /** + * Make the class compatible with previous versions + */ + @Since("2.0.0") + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClasses: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int, Int], + minInstancesPerNode: Int, + minInfoGain: Double, + maxMemoryInMB: Int, + subsamplingRate: Double, + useNodeIdCache: Boolean, + checkpointInterval: Int) { + this(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, + subsamplingRate, useNodeIdCache, checkpointInterval, Array()) + } + /** * Java-friendly constructor for 
[[org.apache.spark.mllib.tree.configuration.Strategy]] */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 3f25adc64b9e..6040e77f797a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -238,7 +238,7 @@ private object RandomForestClassifierSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, - rf.getOldImpurity, rf.getClassWeights) + rf.getOldImpurity) val oldModel = OldRandomForest.trainClassifier( data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 2cdcfb9c6b08..2c15822252c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -141,7 +141,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses = 0, - OldAlgo.Regression, rf.getOldImpurity, classWeights = Array()) + OldAlgo.Regression, rf.getOldImpurity) val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 270104f85b83..57c275baed21 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -94,7 +94,7 @@ This file is divided into 3 sections: - + From 9c99973476c9143535a913761ceced3ab1d73541 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Tue, 21 Jun 2016 11:36:40 -0700 Subject: [PATCH 16/18] first version that passes all tests --- .../DecisionTreeClassifier.scala | 2 +- .../RandomForestClassifier.scala | 5 +-- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 5 +-- .../org/apache/spark/ml/tree/treeParams.scala | 94 ++++++++++++++----- .../mllib/tree/configuration/Strategy.scala | 2 +- .../DecisionTreeClassifierSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- .../RandomForestRegressorSuite.scala | 3 +- 9 files changed, 83 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 52093b358e2f..59aaa1cd457a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -122,7 +122,7 @@ class DecisionTreeClassifier @Since("1.4.0") ( categoricalFeatures: Map[Int, Int], numClasses: Int): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, - subsamplingRate = 1.0) + subsamplingRate = 1.0, getClassWeights) } @Since("1.4.1") diff --git
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 16afb9e57a12..dc6b56fe2b92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -107,8 +107,9 @@ class RandomForestClassifier @Since("1.4.0") ( val numClasses: Int = getNumClasses(dataset) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = - super.getOldStrategy(categoricalFeatures, numClasses, - OldAlgo.Classification, getOldImpurity) + super.getOldStrategy(categoricalFeatures = categoricalFeatures, numClasses = numClasses, + oldAlgo = OldAlgo.Classification, oldImpurity = getOldImpurity, + subsamplingRate = getSubsamplingRate, classWeights = getClassWeights) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index c4df9d11127f..92cfedbd785b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** (private[ml]) Create a Strategy instance to use with the old API. */ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, - subsamplingRate = 1.0) + 1.0, Array()) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index ba29be16561f..4385c7c640d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -99,7 +99,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0, - OldAlgo.Regression, getOldImpurity) + OldAlgo.Regression, getOldImpurity, getSubsamplingRate, + classWeights = Array()) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 87dea5f8d47e..45f34d7c1a9b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -102,17 +102,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + " trees.") - /** - * An array that stores the weights of class labels. All elements must be non-negative. - * (default = Array(1, 1)) - * @group Param - */ - final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + - " that stores the weights of class labels. 
All elements must be non-negative.") - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10, - classWeights -> Array(1.0, 1.0)) + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) /** @group setParam */ def setMaxDepth(value: Int): this.type = set(maxDepth, value) @@ -153,12 +144,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** @group expertGetParam */ final def getCacheNodeIds: Boolean = $(cacheNodeIds) - /** @group SetParam */ - def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) - - /** @group GetParam */ - final def getClassWeights: Array[Double] = $(classWeights) - /** * Specifies how often to checkpoint the cached node IDs. * E.g. 10 means that the cache will get checkpointed every 10 iterations. @@ -176,7 +161,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams numClasses: Int, oldAlgo: OldAlgo.Algo, oldImpurity: OldImpurity, - subsamplingRate: Double): OldStrategy = { + subsamplingRate: Double, + classWeights: Array[Double]): OldStrategy = { val strategy = OldStrategy.defaultStrategy(oldAlgo) strategy.impurity = oldImpurity strategy.checkpointInterval = getCheckpointInterval @@ -189,9 +175,32 @@ private[ml] trait DecisionTreeParams extends PredictorParams strategy.numClasses = numClasses strategy.categoricalFeaturesInfo = categoricalFeatures strategy.subsamplingRate = subsamplingRate - strategy.classWeights = getClassWeights + strategy.classWeights = classWeights strategy } + + /** (private[ml]) Create a Strategy whose interface is compatible with the old API. */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity, + subsamplingRate: Double): OldStrategy = { + val strategy = OldStrategy.defaultStrategy(oldAlgo) + strategy.impurity = oldImpurity + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = subsamplingRate + strategy.classWeights = Array(1.0, 1.0) + strategy + } } /** @@ -210,7 +219,15 @@ private[ml] trait TreeClassifierParams extends Params { s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) - setDefault(impurity -> "gini") + /** + * An array that stores the weights of class labels. All elements must be non-negative. + * (default = Array(1.0, 1.0)) + * @group param + */ + final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + + " that stores the weights of class labels.
All elements must be non-negative.") + + setDefault(impurity -> "gini", classWeights -> Array(1.0, 1.0)) /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -218,6 +235,12 @@ private[ml] trait TreeClassifierParams extends Params { /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase + /** @group setParam */ + def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) + + /** @group getParam */ + final def getClassWeights: Array[Double] = $(classWeights) + /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { getImpurity match { @@ -257,7 +280,16 @@ private[ml] trait TreeRegressorParams extends Params { s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) - setDefault(impurity -> "variance") + /** + * An array that stores the weights of class labels. This parameter will be ignored in + * regression trees. + * (default = Array()) + * @group param
 */ + final val classWeights: DoubleArrayParam = new DoubleArrayParam(this, "classWeights", "An array" + + " that stores the weights of class labels. All elements must be non-negative.") + + setDefault(impurity -> "variance", classWeights -> Array()) /** @group setParam */ def setImpurity(value: String): this.type = set(impurity, value) @@ -265,6 +297,12 @@ private[ml] trait TreeRegressorParams extends Params { /** @group getParam */ final def getImpurity: String = $(impurity).toLowerCase + /** @group setParam */ + def setClassWeights(value: Array[Double]): this.type = set(classWeights, value) + + /** @group getParam */ + final def getClassWeights: Array[Double] = $(classWeights) + /** Convert new impurity to old impurity. */ private[ml] def getOldImpurity: OldImpurity = { getImpurity match { @@ -330,9 +368,19 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { categoricalFeatures: Map[Int, Int], numClasses: Int, oldAlgo: OldAlgo.Algo, - oldImpurity: OldImpurity): OldStrategy = { + oldImpurity: OldImpurity, + classWeights: Array[Double]): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, - oldImpurity, getSubsamplingRate) + oldImpurity, getSubsamplingRate, classWeights) + } + + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, + oldImpurity, getSubsamplingRate, Array(1.0, 1.0)) + } } @@ -475,7 +523,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS categoricalFeatures: Map[Int, Int], oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, - oldAlgo, OldVariance) + oldAlgo, OldVariance, Array(1.0, 1.0)) // NOTE: The old API does not support "seed" so we ignore it.
new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 30cac57451e3..e96350db6bb1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -102,7 +102,7 @@ class Strategy @Since("1.3.0") ( } /** - * Make the class compatible with previous versions + * Make the Strategy class compatible with the old API */ @Since("2.0.0") def this( diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c26f8737bdf7..d4c7dc4db09b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -72,7 +72,7 @@ class DecisionTreeClassifierSuite // Tests calling train() ///////////////////////////////////////////////////////////////////////////// - test("Binary classification with setting explicit uniform class weights") { + test("Binary classification with explicitly set uniform class weights") { val dt = new DecisionTreeClassifier() .setImpurity("WeightedGini") .setMaxDepth(2) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 6040e77f797a..fd7be885b448 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -238,7 +238,7 @@ private object RandomForestClassifierSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, - rf.getOldImpurity) + rf.getOldImpurity, rf.getSubsamplingRate, rf.getClassWeights) val oldModel = OldRandomForest.trainClassifier( data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 2c15822252c1..169dcdd3f567 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -141,7 +141,8 @@ private object RandomForestRegressorSuite extends SparkFunSuite { val numFeatures = data.first().features.size val oldStrategy = rf.getOldStrategy(categoricalFeatures, numClasses = 0, - OldAlgo.Regression, rf.getOldImpurity) + OldAlgo.Regression, rf.getOldImpurity, rf.getSubsamplingRate, + classWeights = Array()) val oldModel = OldRandomForest.trainRegressor(data.map(OldLabeledPoint.fromML), oldStrategy, rf.getNumTrees, rf.getFeatureSubsetStrategy, rf.getSeed.toInt) val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) From f53a2ccf001ec7db54ebff010fa321e3c116d9a5 Mon Sep 17 00:00:00 2001 From: Yuewei Na Date: Tue, 21 Jun 2016 14:21:27 -0700 Subject: [PATCH 17/18] minor modifications for code styling ---
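Note (not part of the diff): as background for the WeightedGini doc change below, a standalone sketch of the "altered prior" computation the impurity implements. Per-class counts are first scaled by the class weights, and the ordinary Gini formula is then applied to the weighted frequencies; the function and the numbers here are illustrative only, not the MLlib implementation.

    // Standalone sketch: weighted Gini impurity via rescaled class counts.
    def weightedGini(counts: Array[Double], weights: Array[Double]): Double = {
      val weighted = counts.zip(weights).map { case (c, w) => c * w }
      val total = weighted.sum
      if (total == 0.0) 0.0
      else 1.0 - weighted.map { c => val f = c / total; f * f }.sum
    }

    weightedGini(Array(990.0, 10.0), Array(1.0, 1.0))  // ~0.0198, plain Gini on a 99:1 split
    weightedGini(Array(990.0, 10.0), Array(1.0, 99.0)) // 0.5, the weights rebalance the classes
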
.../RandomForestClassifier.scala | 5 +-- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 5 +-- .../org/apache/spark/ml/tree/treeModels.scala | 4 ++ .../org/apache/spark/ml/tree/treeParams.scala | 40 +++++++++---------- .../mllib/tree/impurity/WeightedGini.scala | 21 +++++----- .../DecisionTreeClassifierSuite.scala | 3 -- .../RandomForestClassifierSuite.scala | 5 +-- 8 files changed, 40 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index dc6b56fe2b92..5e61b759c7c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -107,9 +107,8 @@ class RandomForestClassifier @Since("1.4.0") ( val numClasses: Int = getNumClasses(dataset) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) val strategy = - super.getOldStrategy(categoricalFeatures = categoricalFeatures, numClasses = numClasses, - oldAlgo = OldAlgo.Classification, oldImpurity = getOldImpurity, - subsamplingRate = getSubsamplingRate, classWeights = getClassWeights) + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, + getOldImpurity, getSubsamplingRate, getClassWeights) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 92cfedbd785b..b2fe5ded6179 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -117,7 +117,7 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S /** (private[ml]) Create a Strategy instance to use with the old API. 
*/ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, - 1.0, Array()) + subsamplingRate = 1.0, classWeights = Array()) } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 4385c7c640d9..9429a053804d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -98,9 +98,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = - super.getOldStrategy(categoricalFeatures, numClasses = 0, - OldAlgo.Regression, getOldImpurity, getSubsamplingRate, - classWeights = Array()) + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, + getOldImpurity, getSubsamplingRate, classWeights = Array()) val instr = Instrumentation.create(this, oldDataset) instr.logParams(params: _*) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 2299d9744f1f..029ccfec2e2c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -342,6 +342,8 @@ private[ml] object DecisionTreeModelReadWrite { Param.jsonDecode[String](compact(render(impurityJson))) } + // Get class weights to construct ImpurityCalculator. This value + // is ignored unless the impurity is WeightedGini val classWeights: Array[Double] = { val classWeightsJson: JValue = metadata.getParamValue("classWeights") compact(render(classWeightsJson)).split("\\[|,|\\]") @@ -445,6 +447,8 @@ private[ml] object EnsembleModelReadWrite { Param.jsonDecode[String](compact(render(impurityJson))) } + // Get class weights to construct ImpurityCalculator. This value + // is ignored unless the impurity is WeightedGini val classWeights: Array[Double] = { val classWeightsJson: JValue = metadata.getParamValue("classWeights") val classWeightsArray = compact(render(classWeightsJson)).split("\\[|,|\\]") diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 45f34d7c1a9b..aba5ab1aec45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -155,7 +155,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - /** (private[ml]) Create a Strategy instance to use with the old API. */ + /** (private[ml]) Create a Strategy instance. */ private[ml] def getOldStrategy( categoricalFeatures: Map[Int, Int], numClasses: Int, @@ -181,25 +181,25 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** (private[ml]) Create a Strategy whose interface is compatible with the old API. 
    */
   private[ml] def getOldStrategy(
-      categoricalFeatures: Map[Int, Int],
-      numClasses: Int,
-      oldAlgo: OldAlgo.Algo,
-      oldImpurity: OldImpurity,
-      subsamplingRate: Double): OldStrategy = {
-    val strategy = OldStrategy.defaultStrategy(oldAlgo)
-    strategy.impurity = oldImpurity
-    strategy.checkpointInterval = getCheckpointInterval
-    strategy.maxBins = getMaxBins
-    strategy.maxDepth = getMaxDepth
-    strategy.maxMemoryInMB = getMaxMemoryInMB
-    strategy.minInfoGain = getMinInfoGain
-    strategy.minInstancesPerNode = getMinInstancesPerNode
-    strategy.useNodeIdCache = getCacheNodeIds
-    strategy.numClasses = numClasses
-    strategy.categoricalFeaturesInfo = categoricalFeatures
-    strategy.subsamplingRate = subsamplingRate
-    strategy.classWeights = Array(1.0, 1.0)
-    strategy
+    categoricalFeatures: Map[Int, Int],
+    numClasses: Int,
+    oldAlgo: OldAlgo.Algo,
+    oldImpurity: OldImpurity,
+    subsamplingRate: Double): OldStrategy = {
+    val strategy = OldStrategy.defaultStrategy(oldAlgo)
+    strategy.impurity = oldImpurity
+    strategy.checkpointInterval = getCheckpointInterval
+    strategy.maxBins = getMaxBins
+    strategy.maxDepth = getMaxDepth
+    strategy.maxMemoryInMB = getMaxMemoryInMB
+    strategy.minInfoGain = getMinInfoGain
+    strategy.minInstancesPerNode = getMinInstancesPerNode
+    strategy.useNodeIdCache = getCacheNodeIds
+    strategy.numClasses = numClasses
+    strategy.categoricalFeaturesInfo = categoricalFeatures
+    strategy.subsamplingRate = subsamplingRate
+    strategy.classWeights = Array(1.0, 1.0)
+    strategy
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
index 293ac71dd139..90232d07a691 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/WeightedGini.scala
@@ -21,11 +21,10 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
 
 /**
  * :: Experimental ::
- * Class for calculating the
- * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
- * during binary classification.
+ * Class for calculating the Gini impurity with class weights using
+ * the altered-prior method during classification.
  */
-@Since("1.0.0")
+@Since("2.0.0")
 @Experimental
 object WeightedGini extends Impurity {
 
@@ -36,7 +35,7 @@ object WeightedGini extends Impurity {
    * @param weightedTotalCount sum of counts for all labels
    * @return information value, or 0 if weightedTotalCount = 0
    */
-  @Since("1.1.0")
+  @Since("2.0.0")
   @DeveloperApi
   override def calculate(weightedCounts: Array[Double], weightedTotalCount: Double): Double = {
     if (weightedTotalCount == 0) {
@@ -61,7 +60,7 @@ object WeightedGini extends Impurity {
    * @param sumSquares summation of squares of the labels
    * @return information value, or 0 if count = 0
    */
-  @Since("1.0.0")
+  @Since("2.0.0")
   @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("WeightedGini.calculate")
@@ -70,7 +69,7 @@ object WeightedGini extends Impurity {
    * Get this impurity instance.
    * This is useful for passing impurity parameters to a Strategy in Java.
    */
-  @Since("1.1.0")
+  @Since("2.0.0")
   def instance: this.type = this
 
 }
@@ -179,10 +178,10 @@ private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights:
       s"Two ImpurityCalculator instances cannot be added with different counts sizes."
+ s" Sizes are ${stats.length} and ${other.stats.length}.") val otherCalculator = other.asInstanceOf[WeightedGiniCalculator] + val len = otherCalculator.stats.length var i = 0 - val len = other.stats.length while (i < len) { - stats(i) += other.stats(i) + stats(i) += otherCalculator.stats(i) weightedStats(i) += otherCalculator.weightedStats(i) i += 1 } @@ -198,10 +197,10 @@ private[spark] class WeightedGiniCalculator(stats: Array[Double], classWeights: s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + s" Sizes are ${stats.length} and ${other.stats.length}.") val otherCalculator = other.asInstanceOf[WeightedGiniCalculator] + val len = otherCalculator.stats.length var i = 0 - val len = other.stats.length while (i < len) { - stats(i) -= other.stats(i) + stats(i) -= otherCalculator.stats(i) weightedStats(i) -= otherCalculator.weightedStats(i) i += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index d4c7dc4db09b..096ab2467ab8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -30,9 +30,6 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} - - - class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index fd7be885b448..5ea110ec0d02 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.classification -import scala.io.Source - import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{Vector, Vectors} @@ -34,7 +32,6 @@ import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} - /** * Test suite for [[RandomForestClassifier]]. 
  */
@@ -213,7 +210,7 @@ class RandomForestClassifierSuite
     assert(model.numClasses === model2.numClasses)
   }
 
-    val rf = new RandomForestClassifier().setNumTrees(2).setClassWeights(Array())
+    val rf = new RandomForestClassifier().setNumTrees(2)
     val rdd = TreeTests.getTreeReadWriteData(sc)
 
     val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy")

From ec5f6df4a5c11053f4a57201a09b631a5f8b81cf Mon Sep 17 00:00:00 2001
From: Yuewei Na
Date: Wed, 22 Jun 2016 10:42:10 -0700
Subject: [PATCH 18/18] Minor code style modifications

---
 .../org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index b32822477224..dce4e698b82c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -93,7 +93,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(6), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0, Array()
+      0, 0, 0.0, 0, 0, Array[Double]()
     )
     val featureSamples = Array.fill(200000)(math.random)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -110,7 +110,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(5), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0, Array()
+      0, 0, 0.0, 0, 0, Array[Double]()
     )
     val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -124,7 +124,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(3), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0, Array()
+      0, 0, 0.0, 0, 0, Array[Double]()
     )
     val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
@@ -138,7 +138,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
       Map(), Set(),
       Array(3), Gini, QuantileStrategy.Sort,
-      0, 0, 0.0, 0, 0, Array()
+      0, 0, 0.0, 0, 0, Array[Double]()
     )
     val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
     val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
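
For reference, a self-contained sketch of the altered-prior computation that
WeightedGini implements. The helper below is an illustration, and its name is
an assumption for exposition: it folds into one function the two steps the
patch splits between WeightedGiniCalculator (scaling raw class counts by the
class weights) and WeightedGini.calculate (the usual 1 - sum(freq^2) Gini
formula applied to the weighted counts).

  // Illustrative only: scale raw class counts by class weights (the
  // "altered prior"), then apply the standard Gini impurity formula.
  def weightedGini(counts: Array[Double], weights: Array[Double]): Double = {
    require(counts.length == weights.length, "one weight per class")
    val weighted = counts.zip(weights).map { case (c, w) => c * w }
    val total = weighted.sum
    if (total == 0.0) 0.0
    else 1.0 - weighted.map { c => val f = c / total; f * f }.sum
  }

  weightedGini(Array(90.0, 10.0), Array(1.0, 1.0))  // 0.18 -- plain Gini
  weightedGini(Array(90.0, 10.0), Array(1.0, 9.0))  // 0.5  -- as if the classes were balanced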