From 50b143a4385f209fbc1793f3e03134cab3ab9583 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 20 Apr 2014 13:33:03 -0700 Subject: [PATCH 01/72] adding support for very deep trees --- .../spark/mllib/tree/DecisionTree.scala | 85 +++++++++++++++++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 12 +-- 2 files changed, 85 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3019447ce4cd..ad901d4f6739 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -58,7 +58,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) - logDebug("numSplits = " + bins(0).length) + val numBins = bins(0).length + logDebug("numBins = " + numBins) // depth of the decision tree val maxDepth = strategy.maxDepth @@ -72,7 +73,28 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val parentImpurities = new Array[Double](maxNumNodes) // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) + // num features + val numFeatures = input.take(1)(0).features.size + + // Calculate level for single group construction + // Max memory usage for aggregates + val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB + logDebug("max memory usage for aggregates = " + maxMemoryUsage) + val numElementsPerNode = { + strategy.algo match { + case Classification => 2 * numBins * numFeatures + case Regression => 3 * numBins * numFeatures + } + } + logDebug("numElementsPerNode = " + numElementsPerNode) + val arraySizePerNode = 8 * numElementsPerNode //approx. memory usage for bin aggregate array + val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1) + logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) + // nodes at a level is 2^(level-1). level is zero indexed. + val maxLevelForSingleGroup = scala.math.max( + (scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0) + logDebug("max level for single group = " + maxLevelForSingleGroup) /* * The main idea here is to perform level-wise training of the decision tree nodes thus @@ -92,7 +114,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins) + level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. @@ -110,6 +132,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } + logDebug("#####################################") + logDebug("Extracting tree model") + logDebug("#####################################") + // Initialize the top or root node of the tree. val topNode = nodes(0) // Build the full tree using the node info calculated in the level-wise best split calculations. 
@@ -260,6 +286,7 @@ object DecisionTree extends Serializable with Logging { * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features + * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array of splits with best splits for all nodes at a given level. */ protected[tree] def findBestSplits( @@ -269,7 +296,50 @@ object DecisionTree extends Serializable with Logging { level: Int, filters: Array[List[Filter]], splits: Array[Array[Split]], - bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + bins: Array[Array[Bin]], + maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { + // split into groups to avoid memory overflow during aggregation + if (level > maxLevelForSingleGroup) { + val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt + logDebug("numGroups = " + numGroups) + var groupIndex = 0 + var bestSplits = new Array[(Split, InformationGainStats)](0) + while (groupIndex < numGroups) { + val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, + filters, splits, bins, numGroups, groupIndex) + bestSplits = Array.concat(bestSplits, bestSplitsForGroup) + groupIndex += 1 + } + bestSplits + } else { + findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins) + } + } + + /** + * Returns an array of optimal splits for a group of nodes at a given level. + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for constructing the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of splits with best splits for all nodes at a given level. + */ + private def findBestSplitsPerGroup( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]], + numGroups: Int = 1, + groupIndex: Int = 0): Array[(Split, InformationGainStats)] = { /* * The high-level description for the best split optimizations are noted here. @@ -296,7 +366,7 @@ object DecisionTree extends Serializable with Logging { */ // common calculations for multiple nested methods - val numNodes = scala.math.pow(2, level).toInt + val numNodes = scala.math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. val numFeatures = input.first().features.size @@ -304,12 +374,15 @@ object DecisionTree extends Serializable with Logging { val numBins = bins(0).length logDebug("numBins = " + numBins) + // shift when more than one group is used at deep tree levels + val groupShift = numNodes * groupIndex + /** Find the filters used before reaching the current code.
*/ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift filters(nodeFilterIndex) } } @@ -878,7 +951,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 350130c914f2..e21db8a3bb8c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -254,7 +254,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -281,7 +281,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) val split = bestSplits(0)._1 assert(split.categories.length === 1) @@ -310,7 +310,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -333,7 +333,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -357,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) @@ -381,7 +381,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bins(0).length === 100) val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, - Array[List[Filter]](), splits, bins) + Array[List[Filter]](), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) assert(bestSplits(0)._1.threshold === 10) From 
abc5a23bf80d792a345d723b44bff3ee217cd5ac Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Mon, 21 Apr 2014 18:41:36 -0700 Subject: [PATCH 02/72] Parameterizing max memory. --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 8 ++++++-- .../apache/spark/mllib/tree/configuration/Strategy.scala | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ad901d4f6739..ffee3fd84895 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.Utils.memoryStringToMb import org.apache.spark.mllib.linalg.{Vector, Vectors} /** @@ -79,7 +80,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = scala.math.pow(2, 27).toInt //128MB + val maxMemoryUsage = strategy.maxMemory * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage) val numElementsPerNode = { strategy.algo match { @@ -1158,10 +1159,13 @@ object DecisionTree extends Serializable with Logging { val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt val maxBins = options.getOrElse('maxBins, "100").toString.toInt + val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString) - val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory=maxMemUsage) val model = DecisionTree.train(trainData, strategy) + + // Load test data. 
val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 8767aca47cd5..fd7a9ed1514c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -43,4 +43,5 @@ class Strategy ( val maxDepth: Int, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + val maxMemory: Int = 128) extends Serializable From 2f1e093c5187a1ed532f9c19b25f8a2a6a46e27a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 21 Apr 2014 20:49:46 -0700 Subject: [PATCH 03/72] minor: added doc for maxMemory parameter --- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index fd7a9ed1514c..18918ad5c746 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -35,6 +35,9 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. + * @param maxMemory maximum memory in MB allocated to histogram aggregation. Default value is + * 128 MB. + * */ @Experimental class Strategy ( From 02877721328a560f210a7906061108ce5dd4bbbe Mon Sep 17 00:00:00 2001 From: Evan Sparks Date: Tue, 22 Apr 2014 11:13:27 -0700 Subject: [PATCH 04/72] Fixing scalastyle issue. --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index ffee3fd84895..3dd410e933fa 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -89,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } } logDebug("numElementsPerNode = " + numElementsPerNode) - val arraySizePerNode = 8 * numElementsPerNode //approx. memory usage for bin aggregate array + val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) // nodes at a level is 2^(level-1). level is zero indexed. 
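For concreteness, the aggregate-sizing arithmetic that the patches above build up in DecisionTree.train can be exercised on its own. The sketch below mirrors that computation; the object name AggregateSizingSketch and the input values (numBins = 100, numFeatures = 50) are illustrative assumptions, not values from the PR, and the "- 1" in the level calculation reproduces the off-by-one that patch 11 later removes:

    object AggregateSizingSketch {
      def main(args: Array[String]): Unit = {
        val maxMemoryInMB = 128  // Strategy parameter introduced in patch 02
        val numBins = 100        // example value; bins(0).length in the patch
        val numFeatures = 50     // example value; input.take(1)(0).features.size
        val isClassification = true

        val maxMemoryUsage = maxMemoryInMB * 1024 * 1024
        // 2 aggregates per (feature, bin) for binary classification, 3 for regression.
        val numElementsPerNode =
          if (isClassification) 2 * numBins * numFeatures else 3 * numBins * numFeatures
        val arraySizePerNode = 8 * numElementsPerNode // 8 bytes per Double
        val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1)
        // Patch 01 subtracts 1 here; patch 11 ("fixed off by 1 error") drops the
        // subtraction, since a zero-indexed level holds 2^level nodes.
        val maxLevelForSingleGroup = math.max(
          (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0)
        // Deeper levels are processed in 2^(level - maxLevelForSingleGroup) groups.
        val level = maxLevelForSingleGroup + 2
        val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
        println(s"maxNumberOfNodesPerGroup = $maxNumberOfNodesPerGroup")
        println(s"maxLevelForSingleGroup = $maxLevelForSingleGroup")
        println(s"numGroups at level $level = $numGroups")
      }
    }

With these example inputs the sketch prints 1677 nodes per group and a single-group limit of level 9 (level 10 once patch 11 drops the extra subtraction).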
From 719d0098bb08b50e523cec3e388115d5a206512b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 23 Apr 2014 17:04:05 -0700 Subject: [PATCH 05/72] updating user documentation --- docs/mllib-classification-regression.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md index 2c42f60c2ecc..b06e799577fb 100644 --- a/docs/mllib-classification-regression.md +++ b/docs/mllib-classification-regression.md @@ -294,12 +294,9 @@ The recursive tree construction is stopped at a node when one of the two conditi 1. The node depth is equal to the `maxDepth` training paramemter 2. No split candidate leads to an information gain at the node. -### Practical Limitations - -The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* in memory for aggregating histograms over partitions. The current implementation might not scale to very deep trees since the memory requirement grows exponentially with tree depth. - -Please drop us a line if you encounter any issues. We are planning to solve this problem in the near future and real-world examples will be great. +### Implementation Details +The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* in memory for aggregating histograms over partitions. Based upon the 'maxMemory' parameter set during training (default is 128 MB), the task is broken down into smaller groups to avoid out-of-memory errors during computation. ## Implementation in MLlib From 15171550fe83e42fcb707744c9035ed540fb78d1 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 29 Apr 2014 14:45:34 -0700 Subject: [PATCH 06/72] updated documentation --- docs/mllib-decision-tree.md | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 069376699073..6667911a6aba 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -95,15 +95,9 @@ The recursive tree construction is stopped at a node when one of the two conditi ### Practical limitations -1. The tree implementation stores an Array[Double] of size *O(#features \* #splits \* 2^maxDepth)* - in memory for aggregating histograms over partitions. The current implementation might not scale - to very deep trees since the memory requirement grows exponentially with tree depth. -2. The implemented algorithm reads both sparse and dense data. However, it is not optimized for +1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for sparse input. -3. Python is not supported in this release. - -We are planning to solve these problems in the near future. Please drop us a line if you encounter -any issues. +2. Python is not supported in this release. 
## Examples From 718506b2a0146a5794261a553847d363b7dfb932 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 16:29:24 -0700 Subject: [PATCH 07/72] added unit test --- .../examples/mllib/DecisionTreeRunner.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 64 ++++++++++++++++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 0bd847d7bab3..9832bec90d7e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -51,7 +51,7 @@ object DecisionTreeRunner { algo: Algo = Classification, maxDepth: Int = 5, impurity: ImpurityType = Gini, - maxBins: Int = 20) + maxBins: Int = 100) def main(args: Array[String]) { val defaultParams = Params() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index e21db8a3bb8c..4a0b399ca3dd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.SparkContext import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.model.Split +import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors @@ -390,6 +391,53 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(bestSplits(0)._2.rightImpurity === 0) assert(bestSplits(0)._2.predict === 1) } + + test("test second level node building with/without groups") { + val arr = DecisionTreeSuite.generateOrderedLabeledPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val leftFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),-1) + val rightFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),1) + val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter)) + val parentImpurities = Array(0.5, 0.5, 0.5) + + // Single group second level tree construction. + val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters, + splits, bins, 10) + assert(bestSplits.length === 2) + assert(bestSplits(0)._2.gain > 0) + assert(bestSplits(1)._2.gain > 0) + + // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second + // level tree construction. 
+ val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, + filters, splits, bins, 0) + assert(bestSplitsWithGroups.length === 2) + assert(bestSplitsWithGroups(0)._2.gain > 0) + assert(bestSplitsWithGroups(1)._2.gain > 0) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until bestSplits.length) { + assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1) + assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain) + assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity) + assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity) + assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity) + assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict) + } + + } + } object DecisionTreeSuite { @@ -412,6 +460,20 @@ object DecisionTreeSuite { arr } + def generateOrderedLabeledPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + if (i < 600){ + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } else { + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } + } + arr + } + def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000){ From e0426ee74d5e233c1e7b14e29135015d09a0370c Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 17:36:47 -0700 Subject: [PATCH 08/72] renamed parameter --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- .../org/apache/spark/mllib/tree/configuration/Strategy.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6f1f3883a7e8..4af6a827946b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -77,7 +77,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Calculate level for single group construction // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemory * 1024 * 1024 + val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage) val numElementsPerNode = { strategy.algo match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 18918ad5c746..eeec2f1621cd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -35,7 +35,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * k) implies the feature n is categorical with k categories 0, * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. - * @param maxMemory maximum memory in MB allocated to histogram aggregation. Default value is + * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. 
* */ @@ -47,4 +47,4 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemory: Int = 128) extends Serializable + val maxMemoryInMB: Int = 128) extends Serializable From dad96523d740c2b7ced0f0d73ade66e528b64064 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 21:59:55 -0700 Subject: [PATCH 09/72] removed unused imports --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4af6a827946b..952f03f10e53 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,8 +28,6 @@ import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.util.Utils.memoryStringToMb -import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: Experimental :: From cbd9f140fd8d43941c61acd6055636bad88b358d Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 3 May 2014 09:16:42 -0700 Subject: [PATCH 10/72] modified scala.math to math --- .../spark/mllib/tree/DecisionTree.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 952f03f10e53..a5a4e61049cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -60,7 +60,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // depth of the decision tree val maxDepth = strategy.maxDepth // the max number of nodes possible given the depth of the tree - val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + val maxNumNodes = math.pow(2, maxDepth).toInt - 1 // Initialize an array to hold filters applied to points for each node. val filters = new Array[List[Filter]](maxNumNodes) // The filter at the top node is an empty list. @@ -85,11 +85,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array - val maxNumberOfNodesPerGroup = scala.math.max(maxMemoryUsage / arraySizePerNode, 1) + val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) // nodes at a level is 2^(level-1). level is zero indexed. 
- val maxLevelForSingleGroup = scala.math.max( - (scala.math.log(maxNumberOfNodesPerGroup) / scala.math.log(2)).floor.toInt - 1, 0) + val maxLevelForSingleGroup = math.max( + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) /* @@ -120,7 +120,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo filters) logDebug("final best split = " + nodeSplitStats._1) } - require(scala.math.pow(2, level) == splitsStatsForLevel.length) + require(math.pow(2, level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -153,7 +153,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes: Array[Node]): Unit = { val split = nodeSplitStats._1 val stats = nodeSplitStats._2 - val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val nodeIndex = math.pow(2, level).toInt - 1 + index val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) logDebug("Node = " + node) @@ -174,7 +174,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo var i = 0 while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. - val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + val nodeIndex = math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { val impurity = if (i == 0) { nodeSplitStats._2.leftImpurity @@ -300,7 +300,7 @@ object DecisionTree extends Serializable with Logging { maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation if (level > maxLevelForSingleGroup) { - val numGroups = scala.math.pow(2, (level - maxLevelForSingleGroup)).toInt + val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt logDebug("numGroups = " + numGroups) var groupIndex = 0 var bestSplits = new Array[(Split, InformationGainStats)](0) @@ -366,7 +366,7 @@ object DecisionTree extends Serializable with Logging { */ // common calculations for multiple nested methods - val numNodes = scala.math.pow(2, level).toInt / numGroups + val numNodes = math.pow(2, level).toInt / numGroups logDebug("numNodes = " + numNodes) // Find the number of features by looking at the first sample. 
val numFeatures = input.first().features.size @@ -382,7 +382,7 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { List[Filter]() } else { - val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + groupShift + val nodeFilterIndex = math.pow(2, level).toInt - 1 + nodeIndex + groupShift filters(nodeFilterIndex) } } @@ -951,7 +951,7 @@ object DecisionTree extends Serializable with Logging { // Iterating over all nodes at this level var node = 0 while (node < numNodes) { - val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + groupShift + val nodeImpurityIndex = math.pow(2, level).toInt - 1 + node + groupShift val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) From 5e822020ce50c6e1bdbdbb3d94d5cbc4c715731e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:34:58 -0700 Subject: [PATCH 11/72] added documentation, fixed off by 1 error in max level calculation --- .../apache/spark/mllib/tree/DecisionTree.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a5a4e61049cc..6c99f82f687e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -76,10 +76,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 - logDebug("max memory usage for aggregates = " + maxMemoryUsage) + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") val numElementsPerNode = { strategy.algo match { - case Classification => 2 * numBins * numFeatures + case Classification => 2 * numBins * numFeatures case Regression => 3 * numBins * numFeatures } } @@ -87,9 +87,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) - // nodes at a level is 2^(level-1). level is zero indexed. + // nodes at a level is 2^level. level is zero indexed. val maxLevelForSingleGroup = math.max( - (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt - 1, 0) + (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) logDebug("max level for single group = " + maxLevelForSingleGroup) /* @@ -299,11 +299,16 @@ object DecisionTree extends Serializable with Logging { bins: Array[Array[Bin]], maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = { // split into groups to avoid memory overflow during aggregation - if (level > maxLevelForSingleGroup) { + if (level > maxLevelForSingleGroup) { + // When information for all nodes at a given level cannot be stored in memory, + // the nodes are divided into multiple groups at each level with the number of groups + // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, + // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. 
val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt logDebug("numGroups = " + numGroups) - var groupIndex = 0 var bestSplits = new Array[(Split, InformationGainStats)](0) + // Iterate over each group of nodes at a level. + var groupIndex = 0 while (groupIndex < numGroups) { val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, numGroups, groupIndex) From 4731cda7b08fdcd365dd1b690ac04a26f6e85657 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:44:39 -0700 Subject: [PATCH 12/72] formatting --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 3 ++- .../apache/spark/mllib/tree/DecisionTreeSuite.scala | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6c99f82f687e..4d7ac51e2f01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -275,7 +275,8 @@ object DecisionTree extends Serializable with Logging { private val InvalidBinIndex = -1 /** - * Returns an array of optimal splits for all nodes at a given level + * Returns an array of optimal splits for all nodes at a given level. Splits the tasks into + * multiple groups if the level-wise training tasks could lead to memory overflow. * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 4a0b399ca3dd..2155ed7b4a15 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -405,8 +405,8 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),-1) - val rightFilter = Filter(new Split(0,400,FeatureType.Continuous,List()),1) + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()) ,1) val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter)) val parentImpurities = Array(0.5, 0.5, 0.5) @@ -444,7 +444,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } @@ -453,7 +453,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } @@ -462,7 +462,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { if (i < 600){ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp @@ -476,7 +476,7 @@ object DecisionTreeSuite { def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new 
Array[LabeledPoint](1000) - for (i <- 0 until 1000){ + for (i <- 0 until 1000) { if (i < 600){ arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { From 5eca9e4fbd0e27e335d5cea0bf26b1a436be0457 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:47:14 -0700 Subject: [PATCH 13/72] grammar --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4d7ac51e2f01..1f99f28e991f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -275,8 +275,8 @@ object DecisionTree extends Serializable with Logging { private val InvalidBinIndex = -1 /** - * Returns an array of optimal splits for all nodes at a given level. Splits the tasks into - * multiple groups if the level-wise training tasks could lead to memory overflow. + * Returns an array of optimal splits for all nodes at a given level. Splits the task into + * multiple groups if the level-wise training task could lead to memory overflow. * * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * for DecisionTree From 8053fed22249bc788ba988489caa22f732b6416d Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 23:48:02 -0700 Subject: [PATCH 14/72] more formatting --- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2155ed7b4a15..51802706d2fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -405,9 +405,9 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { assert(splits(0).length === 99) assert(bins(0).length === 100) - val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()), -1) - val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous,List()) ,1) - val filters = Array[List[Filter]](List(),List(leftFilter),List(rightFilter)) + val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) + val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) + val filters = Array[List[Filter]](List(),List(leftFilter), List(rightFilter)) val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction. From 426bb285f16c816b19e5c25518024ae4d2141c1a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 00:16:02 -0700 Subject: [PATCH 15/72] programming guide blurb --- docs/mllib-decision-tree.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 6667911a6aba..a2a2999a00e3 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -93,6 +93,10 @@ The recursive tree construction is stopped at a node when one of the two conditi 1. The node depth is equal to the `maxDepth` training parammeter 2. No split candidate leads to an information gain at the node. +### Max memory requirements + +For faster processing, the decision tree algorithm performs simultaneous histogram computations for all nodes at each level of the tree. 
This can lead to high memory requirements at deeper levels of the tree and, in turn, to memory overflow errors. To alleviate this problem, a `maxMemoryInMB` training parameter is provided which specifies the maximum amount of memory at the workers (twice as much at the master) to be allocated to the histogram computation. The default value is conservatively chosen to be 128 MB to allow the decision tree algorithm to work in most scenarios. Once the memory requirement for a level-wise computation crosses the `maxMemoryInMB` threshold, the node training tasks at each subsequent level are split into smaller tasks. + ### Practical limitations 1. The implemented algorithm reads both sparse and dense data. However, it is not optimized for From b27ad2c20edb8a7bf0c0edd5d82a6a683b5d9ea2 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 00:19:10 -0700 Subject: [PATCH 16/72] formatting --- .../apache/spark/mllib/tree/configuration/Strategy.scala | 2 +- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index eeec2f1621cd..1b505fd76eb7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -36,7 +36,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * 1, 2, ... , k-1. It's important to note that features are * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is - * 128 MB. + * 128 MB. * */ @Experimental diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 51802706d2fc..bc3b1a3fbe95 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -407,7 +407,7 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { val leftFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()), -1) val rightFilter = Filter(new Split(0, 400, FeatureType.Continuous, List()) ,1) - val filters = Array[List[Filter]](List(),List(leftFilter), List(rightFilter)) + val filters = Array[List[Filter]](List(), List(leftFilter), List(rightFilter)) val parentImpurities = Array(0.5, 0.5, 0.5) // Single group second level tree construction.
@@ -463,7 +463,7 @@ object DecisionTreeSuite { def generateOrderedLabeledPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - if (i < 600){ + if (i < 600) { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } else { @@ -477,7 +477,7 @@ object DecisionTreeSuite { def generateCategoricalDataPoints(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - if (i < 600){ + if (i < 600) { arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) From ce004a1ab63405e0a5d0bc892a48b1c96c4d6605 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 10:29:04 -0700 Subject: [PATCH 17/72] minor formatting --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1f99f28e991f..c3cbe2c63ab0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -77,12 +77,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = { + val numElementsPerNode = strategy.algo match { case Classification => 2 * numBins * numFeatures case Regression => 3 * numBins * numFeatures } - } + logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) From 7fc95457ec66023ddf14e0ef3e1e18cbf828a4db Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 7 May 2014 10:47:27 -0700 Subject: [PATCH 18/72] added docs --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c3cbe2c63ab0..0fe30a3e7040 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -334,6 +334,8 @@ object DecisionTree extends Serializable with Logging { * @param filters Filters for all nodes at a given level * @param splits possible splits for all features * @param bins possible bins for all features + * @param numGroups total number of node groups at the current level. Default value is set to 1. + * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. 
*/ private def findBestSplitsPerGroup( From a1a6e09d7858d82a4b91d40dfd3aeb83f4da2a06 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 21:57:42 -0700 Subject: [PATCH 19/72] added weighted point class --- .../mllib/point/WeightedLabeledPoint.scala | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala new file mode 100644 index 000000000000..f7effcf182db --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.point + +import org.apache.spark.mllib.linalg.Vector + +/** + * Class that represents the features, label and weight of a data point. + * @param label Label for this data point. + * @param features List of features for this data point. + * @param weight Weight for this data point. Default is 1.0. + */ +case class WeightedLabeledPoint(label: Double, features: Vector, weight: Double = 1.0) { + override def toString: String = { + "WeightedLabeledPoint(%s, %s, %s)".format(label, features, weight) + } +} \ No newline at end of file From 14aea48d10eca2727a1f79d3f65e508412c911ad Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 22:15:41 -0700 Subject: [PATCH 20/72] changing instance format to weighted labeled point --- .../spark/mllib/tree/DecisionTree.scala | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 0fe30a3e7040..486a1bb16af9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.mllib.point.WeightedLabeledPoint /** * :: Experimental :: @@ -47,13 +48,16 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) + // Cache input RDD for speedup during multiple passes.
- input.cache() + weightedInput.cache() logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. - val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -70,7 +74,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = input.take(1)(0).features.size + val numFeatures = weightedInput.take(1)(0).features.size // Calculate level for single group construction @@ -109,8 +113,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins, maxLevelForSingleGroup) + val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities, + strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { // Extract info for nodes at the current level. @@ -291,7 +295,7 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ protected[tree] def findBestSplits( - input: RDD[LabeledPoint], + input: RDD[WeightedLabeledPoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, @@ -339,7 +343,7 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[LabeledPoint], + input: RDD[WeightedLabeledPoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, @@ -399,7 +403,7 @@ object DecisionTree extends Serializable with Logging { * Find whether the sample is valid input for the current node, i.e., whether it passes through * all the filters for the current node. */ - def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = { // leaf if ((level > 0) & (parentFilters.length == 0)) { return false @@ -438,7 +442,7 @@ object DecisionTree extends Serializable with Logging { */ def findBin( featureIndex: Int, - labeledPoint: LabeledPoint, + labeledPoint: WeightedLabeledPoint, isFeatureContinuous: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -509,7 +513,7 @@ object DecisionTree extends Serializable with Logging { * where b_ij is an integer between 0 and numBins - 1. * Invalid sample is denoted by noting bin for feature 1 as -1. */ - def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. 
val arr = new Array[Double](1 + (numFeatures * numNodes)) arr(0) = labeledPoint.label @@ -982,7 +986,7 @@ object DecisionTree extends Serializable with Logging { * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) */ protected[tree] def findSplitsBins( - input: RDD[LabeledPoint], + input: RDD[WeightedLabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() From 455bea92849f9c8f180cf6cbff8989b368d5b9ab Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 30 Apr 2014 22:21:32 -0700 Subject: [PATCH 21/72] fixed tests --- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 35e92d71dc63..92e0f9aeeb52 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree import org.scalatest.FunSuite -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.point.WeightedLabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.model.Split @@ -455,7 +455,7 @@ object DecisionTreeSuite { val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } else { - val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) + val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } } @@ -468,7 +468,7 @@ object DecisionTreeSuite { if (i < 600) { arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { - arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) + arr(i) = new WeightedLabeledPoint(0.0, Vectors.dense(1.0, 0.0)) } } arr From 46f909c01419603d4526685dfcb2b713c8e3c979 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 4 May 2014 16:30:04 -0700 Subject: [PATCH 22/72] todo for multiclass support --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 486a1bb16af9..7c18834ef346 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -569,6 +569,7 @@ object DecisionTree extends Serializable with Logging { // Update the left or right count for one bin. 
val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 + // TODO: Multiclass modification here label match { case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 @@ -679,6 +680,7 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => + // TODO: Modify here val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) val leftCount = left0Count + left1Count @@ -779,6 +781,7 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { strategy.algo match { case Classification => + // TODO: Multiclass modification here // Initialize left and right split aggregates. val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) @@ -904,6 +907,8 @@ object DecisionTree extends Serializable with Logging { binData: Array[Double], nodeImpurity: Double): (Split, InformationGainStats) = { + // TODO: Multiclass modification here + logDebug("node impurity = " + nodeImpurity) // Extract left right node aggregates. @@ -948,6 +953,7 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => + // TODO: Multiclass modification here val shift = 2 * node * numBins * numFeatures val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) binsForNode @@ -997,6 +1003,8 @@ object DecisionTree extends Serializable with Logging { val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) + // TODO: Multiclass modification here + /* * TODO: Add a require statement ensuring #bins is always greater than the categories. * It's a limitation of the current implementation but a reasonable trade-off since features @@ -1041,6 +1049,7 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { + // TODO: Multiclass modification here val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) require(maxFeatureValue < numBins, "number of categories should be less than number " + "of bins") From 4d5f70c4688c1183b754f2133a4d5a11d862070a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 5 May 2014 22:52:08 -0700 Subject: [PATCH 23/72] added multiclass support for find splits bins --- .../spark/mllib/tree/DecisionTree.scala | 117 ++++++++++++------ .../mllib/tree/configuration/Strategy.scala | 10 +- 2 files changed, 90 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 7c18834ef346..1c2f4cd70474 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1006,15 +1006,20 @@ object DecisionTree extends Serializable with Logging { // TODO: Multiclass modification here /* - * TODO: Add a require statement ensuring #bins is always greater than the categories. + * Ensure #bins is always greater than the categories. For multiclass classification, + * #bins should be greater than 2^(maxCategories - 1) - 1. 
 * It's a limitation of the current implementation but a reasonable trade-off since features
     * with large number of categories get favored over continuous features.
     */
    if (strategy.categoricalFeaturesInfo.size > 0) {
      val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
      require(numBins >= maxCategoriesForFeatures)
+      if (strategy.isMultiClassification) {
+        require(numBins > math.pow(2, maxCategoriesForFeatures.toInt) - 1)
+      }
    }

+    // Calculate the number of samples for approximate quantile calculation.
    val requiredSamples = numBins*numBins
    val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
@@ -1048,49 +1053,69 @@ object DecisionTree extends Serializable with Logging {
          val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
          splits(featureIndex)(index) = split
        }
-      } else {
-        // TODO: Multiclass modification here
+      } else { // Categorical feature
        val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
-        require(maxFeatureValue < numBins, "number of categories should be less than number " +
-          "of bins")
-
-        // For categorical variables, each bin is a category. The bins are sorted and they
-        // are ordered by calculating the centroid of their corresponding labels.
-        val centroidForCategories =
-          sampledInput.map(lp => (lp.features(featureIndex),lp.label))
-            .groupBy(_._1)
-            .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
-
-        // Check for missing categorical variables and putting them last in the sorted list.
-        val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
-        for (i <- 0 until maxFeatureValue) {
-          if (centroidForCategories.contains(i)) {
-            fullCentroidForCategories(i) = centroidForCategories(i)
-          } else {
-            fullCentroidForCategories(i) = Double.MaxValue
-          }
-        }
-        // bins sorted by centroids
-        val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-
-        logDebug("centriod for categorical variable = " + categoriesSortedByCentroid)
-
-        var categoriesForSplit = List[Double]()
-        categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
-          case ((key, value), index) =>
-            categoriesForSplit = key :: categoriesForSplit
-            splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
-              categoriesForSplit)
+        // Use different bin/split calculation strategy for multiclass classification
+        if (strategy.isMultiClassification) {
+          // Iterate from 1 to 2^maxFeatureValue leading to 2^(maxFeatureValue- 1) - 1
+          // combinations.
+          var index = 1
+          while (index < math.pow(2.0, maxFeatureValue).toInt) {
+            val categories: List[Double] = extractMultiClassCategories(index, maxFeatureValue)
+            splits(featureIndex)(index)
+              = new Split(featureIndex, Double.MinValue, Categorical, categories)
            bins(featureIndex)(index) = {
              if (index == 0) {
                new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
-                  splits(featureIndex)(0), Categorical, key)
+                  splits(featureIndex)(0), Categorical, index)
              } else {
-                new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
-                  Categorical, key)
+                new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical,
+                  Double.MinValue)
              }
            }
+            index += 1
+          }
+        } else { // regression or binary classification
+
+          // For categorical variables, each bin is a category. The bins are sorted and they
+          // are ordered by calculating the centroid of their corresponding labels.
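The multiclass branch above enumerates category subsets as bitmasks, which is what the maxBins requirement earlier in this patch guards. A standalone sketch of the counts involved (illustrative, not part of the patch): for a feature with k categories the loop visits 2^k - 1 non-empty subsets; since a subset and its complement induce the same partition, only 2^(k-1) - 1 of them are genuinely distinct splits, so the enumeration represents each partition more than once.

    // Counts behind the categorical multiclass enumeration (illustrative).
    object CategoricalSplitCounts {
      def enumerated(k: Int): Int = (1 << k) - 1            // bitmasks 1 .. 2^k - 1
      def distinctSplits(k: Int): Int = (1 << (k - 1)) - 1  // complements collapse

      def main(args: Array[String]): Unit =
        (2 to 5).foreach { k =>
          println(s"k=$k: ${enumerated(k)} enumerated, ${distinctSplits(k)} distinct")
        } // k=2: 3/1, k=3: 7/3, k=4: 15/7, k=5: 31/15
    }

The patch resumes below with the centroid-based ordering used for the binary and regression cases.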
+            val centroidForCategories =
+              sampledInput.map(lp => (lp.features(featureIndex),lp.label))
+                .groupBy(_._1)
+                .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+
+            // Check for missing categorical variables and put them last in the sorted list.
+            val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
+            for (i <- 0 until maxFeatureValue) {
+              if (centroidForCategories.contains(i)) {
+                fullCentroidForCategories(i) = centroidForCategories(i)
+              } else {
+                fullCentroidForCategories(i) = Double.MaxValue
+              }
+            }
+
+            // bins sorted by centroids
+            val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
+
+            logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
+
+            var categoriesForSplit = List[Double]()
+            categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+              case ((key, value), index) =>
+                categoriesForSplit = key :: categoriesForSplit
+                splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
+                  Categorical, categoriesForSplit)
+                bins(featureIndex)(index) = {
+                  if (index == 0) {
+                    new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
+                      splits(featureIndex)(0), Categorical, key)
+                  } else {
+                    new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
+                      Categorical, key)
+                  }
+                }
+            }
+          }
        }
        featureIndex += 1
@@ -1120,4 +1145,24 @@ object DecisionTree extends Serializable with Logging {
        throw new UnsupportedOperationException("approximate histogram not supported yet.")
    }
  }
+
+  /**
+   * Nested method to extract list of eligible categories given an index
+   */
+  private def extractMultiClassCategories(i: Int, maxFeatureValue: Double): List[Double] = {
+    // TODO: Test this
+    var categories = List[Double]()
+    var j = 0
+    while (j < maxFeatureValue) {
+      var copy = i
+      if (copy % 2 != 0) {
+        // updating the list of categories.
+        categories = j.toDouble :: categories
+      }
+      copy = copy >> 1
+      j += 1
+    }
+    categories
+  }
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 1b505fd76eb7..3aa2d5382cb8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -37,6 +37,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 *                                zero-indexed.
 * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
 *                      128 MB.
+ * @param numClassesForClassification number of classes for classification.
Default value is 2
+ *                                     leads to binary classification
 *
 */
@Experimental
@@ -47,4 +49,10 @@ class Strategy (
    val maxBins: Int = 100,
    val quantileCalculationStrategy: QuantileStrategy = Sort,
    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable
+    val maxMemoryInMB: Int = 128,
+    val numClassesForClassification: Int = 2) extends Serializable {
+
+  require(numClassesForClassification >= 2)
+  val isMultiClassification = numClassesForClassification > 2
+
+}

From 3f85a17d4c36bebb4831767dcd364fb12cf44873 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Tue, 6 May 2014 18:01:37 -0700
Subject: [PATCH 24/72] tests for multiclass classification

---
 .../spark/mllib/tree/DecisionTree.scala       |  30 ++--
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 138 ++++++++++++++++--
 2 files changed, 143 insertions(+), 25 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 1c2f4cd70474..6c7097cdb5a9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1003,8 +1003,6 @@ object DecisionTree extends Serializable with Logging {
    val numBins = if (maxBins <= count) maxBins else count.toInt
    logDebug("numBins = " + numBins)

-    // TODO: Multiclass modification here
-
    /*
     * Ensure #bins is always greater than the categories. For multiclass classification,
     * #bins should be greater than 2^(maxCategories - 1) - 1.
@@ -1058,17 +1056,18 @@ object DecisionTree extends Serializable with Logging {
        // Use different bin/split calculation strategy for multiclass classification
        if (strategy.isMultiClassification) {
-          // Iterate from 1 to 2^maxFeatureValue leading to 2^(maxFeatureValue- 1) - 1
+          // Iterate from 0 until 2^maxFeatureValue - 1, generating all 2^maxFeatureValue - 1
          // combinations.
-          var index = 1
-          while (index < math.pow(2.0, maxFeatureValue).toInt) {
-            val categories: List[Double] = extractMultiClassCategories(index, maxFeatureValue)
+          var index = 0
+          while (index < math.pow(2.0, maxFeatureValue).toInt - 1) {
+            val categories: List[Double]
+              = extractMultiClassCategories(index + 1, maxFeatureValue)
            splits(featureIndex)(index)
              = new Split(featureIndex, Double.MinValue, Categorical, categories)
            bins(featureIndex)(index) = {
              if (index == 0) {
                new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
-                  splits(featureIndex)(0), Categorical, index)
+                  splits(featureIndex)(0), Categorical, Double.MinValue)
              } else {
                new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), Categorical,
                  Double.MinValue)
@@ -1147,19 +1146,24 @@ object DecisionTree extends Serializable with Logging {
  }

  /**
-   * Nested method to extract list of eligible categories given an index
+   * Nested method to extract list of eligible categories given an index. It extracts the
+   * positions of ones in the binary representation of the input. If the binary
+   * representation of a number is 01101 (13), the output list should be (3.0, 2.0,
+   * 0.0). maxFeatureValue gives the number of rightmost digits that will be tested for ones.
*/
-  private def extractMultiClassCategories(i: Int, maxFeatureValue: Double): List[Double] = {
-    // TODO: Test this
+  private[tree] def extractMultiClassCategories(
+      input: Int,
+      maxFeatureValue: Int): List[Double] = {
    var categories = List[Double]()
    var j = 0
+    var bitShiftedInput = input
    while (j < maxFeatureValue) {
-      var copy = i
-      if (copy % 2 != 0) {
+      if (bitShiftedInput % 2 != 0) {
        // updating the list of categories.
        categories = j.toDouble :: categories
      }
-      copy = copy >> 1
+      // Right shift by one
+      bitShiftedInput = bitShiftedInput >> 1
      j += 1
    }
    categories
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 92e0f9aeeb52..b5259d8e1882 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -231,6 +231,120 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
    assert(bins(1)(3) === null)
  }

+  test("extract categories from a number for multiclass classification") {
+    val l = DecisionTree.extractMultiClassCategories(13, 10)
+    assert(l.length === 3)
+    assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
+  }
+
+  test("split and bin calculations for categorical variables with multiclass classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+    assert(arr.length === 1000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Classification,
+      Gini,
+      maxDepth = 3,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
+      numClassesForClassification = 3)
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+    // Expecting 2^2 - 1 = 3 bins/splits
+    assert(splits(0)(0).feature === 0)
+    assert(splits(0)(0).threshold === Double.MinValue)
+    assert(splits(0)(0).featureType === Categorical)
+    assert(splits(0)(0).categories.length === 1)
+    assert(splits(0)(0).categories.contains(0.0))
+    assert(splits(1)(0).feature === 1)
+    assert(splits(1)(0).threshold === Double.MinValue)
+    assert(splits(1)(0).featureType === Categorical)
+    assert(splits(1)(0).categories.length === 1)
+    assert(splits(1)(0).categories.contains(0.0))
+
+    assert(splits(0)(1).feature === 0)
+    assert(splits(0)(1).threshold === Double.MinValue)
+    assert(splits(0)(1).featureType === Categorical)
+    assert(splits(0)(1).categories.length === 1)
+    assert(splits(0)(1).categories.contains(1.0))
+    assert(splits(1)(1).feature === 1)
+    assert(splits(1)(1).threshold === Double.MinValue)
+    assert(splits(1)(1).featureType === Categorical)
+    assert(splits(1)(1).categories.length === 1)
+    assert(splits(1)(1).categories.contains(1.0))
+
+    assert(splits(0)(2).feature === 0)
+    assert(splits(0)(2).threshold === Double.MinValue)
+    assert(splits(0)(2).featureType === Categorical)
+    assert(splits(0)(2).categories.length === 2)
+    assert(splits(0)(2).categories.contains(0.0))
+    assert(splits(0)(2).categories.contains(1.0))
+    assert(splits(1)(2).feature === 1)
+    assert(splits(1)(2).threshold === Double.MinValue)
+    assert(splits(1)(2).featureType === Categorical)
+    assert(splits(1)(2).categories.length === 2)
+    assert(splits(1)(2).categories.contains(0.0))
+    assert(splits(1)(2).categories.contains(1.0))
+
+    assert(splits(0)(3) === null)
+
+
+    // Check bins.
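To see why extractMultiClassCategories(13, 10) in the first test above yields List(3.0, 2.0, 0.0): 13 is 0b1101, so the set bits sit at positions 0, 2 and 3, and each found position is prepended to the list. A self-contained mirror of the patched method for experimenting outside the suite (illustrative names, not part of the patch; the bin assertions of the second test continue below):

    // Positions of set bits, most significant first because each
    // found position is prepended to the accumulating list.
    object ExtractCategoriesTrace {
      def extract(input: Int, maxFeatureValue: Int): List[Double] = {
        var categories = List[Double]()
        var bits = input
        var j = 0
        while (j < maxFeatureValue) {
          if (bits % 2 != 0) categories = j.toDouble :: categories
          bits = bits >> 1
          j += 1
        }
        categories
      }

      def main(args: Array[String]): Unit =
        println(extract(13, 10)) // List(3.0, 2.0, 0.0), since 13 = 0b1101
    }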
+ + assert(bins(0)(0).category === Double.MinValue) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(0.0)) + assert(bins(1)(0).category === Double.MinValue) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(0)(1).category === Double.MinValue) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(0.0)) + assert(bins(0)(1).highSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(1)(1).category === Double.MinValue) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 1) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(0)(2).category === Double.MinValue) + assert(bins(0)(2).lowSplit.categories.length === 1) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.length === 2) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).category === Double.MinValue) + assert(bins(1)(2).lowSplit.categories.length === 1) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 2) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + + assert(bins(0)(3) === null) + assert(bins(1)(3) === null) + + } + + test("split and bin calculations for categorical variables with no sample for one category " + + "for multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3), + numClassesForClassification = 3) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + } + test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) @@ -430,29 +544,29 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { object DecisionTreeSuite { - def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { - val arr = new Array[LabeledPoint](1000) + def generateOrderedLabeledPointsWithLabel0(): Array[WeightedLabeledPoint] = { + val arr = new Array[WeightedLabeledPoint](1000) for (i <- 0 until 1000) { - val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } arr } - def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { - val arr = new Array[LabeledPoint](1000) + def generateOrderedLabeledPointsWithLabel1(): Array[WeightedLabeledPoint] = { + val arr = new Array[WeightedLabeledPoint](1000) for (i <- 0 until 1000) { - val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) + val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } arr } - def generateOrderedLabeledPoints(): Array[LabeledPoint] = { - val arr = new Array[LabeledPoint](1000) + def generateOrderedLabeledPoints(): Array[WeightedLabeledPoint] = { + val arr = new 
Array[WeightedLabeledPoint](1000) for (i <- 0 until 1000) { if (i < 600) { - val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } else { val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) @@ -462,11 +576,11 @@ object DecisionTreeSuite { arr } - def generateCategoricalDataPoints(): Array[LabeledPoint] = { - val arr = new Array[LabeledPoint](1000) + def generateCategoricalDataPoints(): Array[WeightedLabeledPoint] = { + val arr = new Array[WeightedLabeledPoint](1000) for (i <- 0 until 1000) { if (i < 600) { - arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) + arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { arr(i) = new WeightedLabeledPoint(0.0, Vectors.dense(1.0, 0.0)) } From 46e06ee0ceb223aee50fa811a35d25090a5c4d42 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 18:05:58 -0700 Subject: [PATCH 25/72] minor mods --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6c7097cdb5a9..49b821d58907 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -562,6 +562,7 @@ object DecisionTree extends Serializable with Logging { val label = arr(0) // Iterate over all features. var featureIndex = 0 + // TODO: Multiclass modification here while (featureIndex < numFeatures) { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex @@ -569,10 +570,8 @@ object DecisionTree extends Serializable with Logging { // Update the left or right count for one bin. val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - // TODO: Multiclass modification here label match { - case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 - case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + case n: Double => agg(aggIndex) = agg(aggIndex + n.toInt) + 1 } featureIndex += 1 } From 6c7af2206e6bd16e8bcc4feb4626bfccb5837c55 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 6 May 2014 23:09:46 -0700 Subject: [PATCH 26/72] prepared for multiclass without breaking binary classification --- .../spark/mllib/tree/DecisionTree.scala | 189 ++++++++++-------- 1 file changed, 107 insertions(+), 82 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 49b821d58907..0ca4366ae6e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -385,6 +385,8 @@ object DecisionTree extends Serializable with Logging { logDebug("numFeatures = " + numFeatures) val numBins = bins(0).length logDebug("numBins = " + numBins) + val numClasses = strategy.numClassesForClassification + logDebug("numClasses = " + numClasses) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -545,10 +547,10 @@ object DecisionTree extends Serializable with Logging { * incremented based upon whether the feature is classified as 0 or 1. 
* * @param agg Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures*numNodes for classification + * numClasses * numSplits * numFeatures*numNodes for classification * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size - * 2 * numSplits * numFeatures * numNodes for classification + * numClasses * numSplits * numFeatures * numNodes for classification */ def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { // Iterate over all nodes. @@ -562,16 +564,16 @@ object DecisionTree extends Serializable with Logging { val label = arr(0) // Iterate over all features. var featureIndex = 0 - // TODO: Multiclass modification here while (featureIndex < numFeatures) { // Find the bin index for this feature. val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = 2 * numBins * numFeatures * nodeIndex - val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - label match { - case n: Double => agg(aggIndex) = agg(aggIndex + n.toInt) + 1 + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + label.toInt match { + case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1 } featureIndex += 1 } @@ -632,7 +634,7 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { - case Classification => 2 * numBins * numFeatures * numNodes + case Classification => numClasses * numBins * numFeatures * numNodes case Regression => 3 * numBins * numFeatures * numNodes } logDebug("binAggregateLength = " + binAggregateLength) @@ -672,20 +674,20 @@ object DecisionTree extends Serializable with Logging { * @return information gain and statistics for all splits */ def calculateGainForSplit( - leftNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], featureIndex: Int, splitIndex: Int, - rightNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Array[Double]]], topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => // TODO: Modify here - val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) - val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) + val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0) + val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1) val leftCount = left0Count + left1Count - val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) - val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) + val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0) + val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1) val rightCount = right0Count + right1Count val impurity = { @@ -722,13 +724,13 @@ object DecisionTree extends Serializable with Logging { new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) case Regression => - val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) - val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) - val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) + val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) + val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2) - val rightCount = rightNodeAgg(featureIndex)(3 * 
splitIndex) - val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) - val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0) + val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1) + val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2) val impurity = { if (level > 0) { @@ -777,73 +779,96 @@ object DecisionTree extends Serializable with Logging { * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) */ def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { strategy.algo match { case Classification => // TODO: Multiclass modification here - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(2 * (numBins - 2)) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) - = binData(shift + (2 * (numBins - 1)) + 1) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2) - leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + - leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = - binData(shift + (2 *(numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = - binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) - - splitIndex += 1 + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) + + if (strategy.isMultiClassification) { + var featureIndex = 0 + while (featureIndex < numFeatures){ + val numCategories = strategy.categoricalFeaturesInfo(featureIndex) + val maxSplits = math.pow(2, numCategories) - 1 + var i = 0 + // TODO: Add multiclass case here + while (i < maxSplits) { + var classIndex = 0 + while (classIndex < numClasses) { + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + + classIndex += 1 + } + i += 1 + } + featureIndex += 1 + } + } else { + // Iterate over all features. 
+ var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + + 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = + binData(shift + (2 *(numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = + binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + + splitIndex += 1 + } + featureIndex += 1 } - featureIndex += 1 } (leftNodeAgg, rightNodeAgg) case Regression => // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) - val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) + val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { // shift for this featureIndex val shift = 3 * featureIndex * numBins // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(2) = binData(shift + 2) + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) // right node aggregate for the highest split - rightNodeAgg(featureIndex)(3 * (numBins - 2)) = + rightNodeAgg(featureIndex)(numBins - 2)(0) = binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = + rightNodeAgg(featureIndex)(numBins - 2)(1) = binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = + rightNodeAgg(featureIndex)(numBins - 2)(2) = binData(shift + (3 * (numBins - 1)) + 2) // Iterate over all splits. 
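Both the classification code above and the regression loop below build the left and right aggregates as running sums over bins: the left aggregate accumulates bins from the left edge, the right aggregate from the right edge, so every candidate split is costed in one pass. A standalone sketch of the same idea on a single (feature, statistic) slice (illustrative, not the patched code):

    // Running-sum construction of left/right totals per candidate split.
    object LeftRightAggSketch {
      def leftRightSums(binCounts: Array[Double]): (Array[Double], Array[Double]) = {
        val n = binCounts.length
        val left = new Array[Double](n - 1)   // split s covers bins 0..s
        val right = new Array[Double](n - 1)  // split s covers bins s+1..n-1
        left(0) = binCounts(0)
        right(n - 2) = binCounts(n - 1)
        var i = 1
        while (i < n - 1) {
          left(i) = left(i - 1) + binCounts(i)
          right(n - 2 - i) = right(n - 1 - i) + binCounts(n - 1 - i)
          i += 1
        }
        (left, right)
      }

      def main(args: Array[String]): Unit = {
        val (l, r) = leftRightSums(Array(2.0, 1.0, 4.0, 3.0))
        println(l.mkString(","))  // 2.0,3.0,7.0
        println(r.mkString(","))  // 8.0,7.0,3.0
      }
    }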
@@ -851,24 +876,24 @@ object DecisionTree extends Serializable with Logging { while (splitIndex < numBins - 1) { // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3) - leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) - leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + - leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(splitIndex - 1)(2) // calculating right node aggregate for a split as a sum of right node aggregate of a // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = binData(shift + (3 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) - rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) = binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2) splitIndex += 1 } @@ -882,8 +907,8 @@ object DecisionTree extends Serializable with Logging { * Calculates information gain for all nodes splits. 
*/ def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Double]], - rightNodeAgg: Array[Array[Double]], + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], nodeImpurity: Double): Array[Array[InformationGainStats]] = { val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) From 5c78e1ac257fe4268004fd56231560c0e73493a7 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 11 May 2014 12:49:10 -0700 Subject: [PATCH 27/72] added multiclass support --- .../spark/mllib/tree/DecisionTree.scala | 64 +++++++++++-------- .../spark/mllib/tree/impurity/Entropy.scala | 25 ++++---- .../spark/mllib/tree/impurity/Gini.scala | 25 ++++---- .../spark/mllib/tree/impurity/Impurity.scala | 8 +-- .../spark/mllib/tree/impurity/Variance.scala | 2 +- 5 files changed, 68 insertions(+), 56 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 0ca4366ae6e8..3da92ed89161 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -681,36 +681,47 @@ object DecisionTree extends Serializable with Logging { topImpurity: Double): InformationGainStats = { strategy.algo match { case Classification => - // TODO: Modify here - val left0Count = leftNodeAgg(featureIndex)(splitIndex)(0) - val left1Count = leftNodeAgg(featureIndex)(splitIndex)(1) - val leftCount = left0Count + left1Count - - val right0Count = rightNodeAgg(featureIndex)(splitIndex)(0) - val right1Count = rightNodeAgg(featureIndex)(splitIndex)(1) - val rightCount = right0Count + right1Count + var classIndex = 0 + val leftCounts: Array[Double] = new Array[Double](numClasses) + val rightCounts: Array[Double] = new Array[Double](numClasses) + var leftTotalCount = 0.0 + var rightTotalCount = 0.0 + while (classIndex < numClasses) { + val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex) + val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex) + leftCounts(classIndex) = leftClassCount + leftTotalCount += leftClassCount + rightCounts(classIndex) = rightClassCount + rightTotalCount += rightClassCount + classIndex += 1 + } val impurity = { if (level > 0) { topImpurity } else { // Calculate impurity for root node. 
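The per-class counting introduced in this patch feeds the usual weighted impurity gain: gain = I(parent) - wL * I(left) - wR * I(right), with weights proportional to the child counts. A standalone sketch with Gini as the example criterion (illustrative, not the patched code path, which continues below):

    // Weighted information gain over per-class count arrays, Gini impurity.
    object GainSketch {
      def gini(counts: Array[Double]): Double = {
        val total = counts.sum
        if (total == 0) 0.0
        else 1.0 - counts.map(c => (c / total) * (c / total)).sum
      }

      def gain(left: Array[Double], right: Array[Double]): Double = {
        val parent = left.zip(right).map { case (l, r) => l + r }
        val (nL, nR) = (left.sum, right.sum)
        val n = nL + nR
        gini(parent) - (nL / n) * gini(left) - (nR / n) * gini(right)
      }

      def main(args: Array[String]): Unit =
        println(gain(Array(8.0, 2.0), Array(1.0, 9.0))) // 0.245
    }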
-            strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+            val rootNodeCounts = new Array[Double](numClasses)
+            var classIndex = 0
+            while (classIndex < numClasses) {
+              rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
+              classIndex += 1
+            }
+            strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
          }
        }

-        if (leftCount == 0) {
+        if (leftTotalCount == 0) {
          return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1)
        }
-        if (rightCount == 0) {
+        if (rightTotalCount == 0) {
          return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0)
        }

-        val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
-        val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
+        val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
+        val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)

-        val leftWeight = leftCount.toDouble / (leftCount + rightCount)
-        val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+        val leftWeight = leftTotalCount.toDouble / (leftTotalCount + rightTotalCount)
+        val rightWeight = rightTotalCount.toDouble / (leftTotalCount + rightTotalCount)

        val gain = {
          if (level > 0) {
@@ -720,7 +731,8 @@ object DecisionTree extends Serializable with Logging {
          }
        }

-        val predict = (left1Count + right1Count) / (leftCount + rightCount)
+        //TODO: Make modification here
+        val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount)

        new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
      case Regression =>
        val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
        val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
@@ -793,7 +794,6 @@ object DecisionTree extends Serializable with Logging {
      binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
    strategy.algo match {
      case Classification =>
-        // TODO: Multiclass modification here

        // Initialize left and right split aggregates.
        val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
        val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)

        if (strategy.isMultiClassification) {
          var featureIndex = 0
          while (featureIndex < numFeatures){
            val numCategories = strategy.categoricalFeaturesInfo(featureIndex)
            val maxSplits = math.pow(2, numCategories) - 1
-            var i = 0
-            // TODO: Add multiclass case here
-            while (i < maxSplits) {
+            var splitIndex = 0
+            while (splitIndex < maxSplits) {
              var classIndex = 0
              while (classIndex < numClasses) {
                // shift for this featureIndex
                val shift = numClasses * featureIndex * numBins
-
+                leftNodeAgg(featureIndex)(splitIndex)(classIndex)
+                  = binData(shift + classIndex)
+                rightNodeAgg(featureIndex)(splitIndex)(classIndex)
+                  = binData(shift + numClasses + classIndex)
                classIndex += 1
              }
-              i += 1
+              splitIndex += 1
            }
            featureIndex += 1
          }
        } else {
@@ -931,8 +944,6 @@ object DecisionTree extends Serializable with Logging {
      binData: Array[Double],
      nodeImpurity: Double): (Split, InformationGainStats) = {

-    // TODO: Multiclass modification here
-
    logDebug("node impurity = " + nodeImpurity)

    // Extract left right node aggregates.
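The predict line above still carries a TODO because it only sums the counts of class 1; patch 32 later in this series replaces it with a majority vote over the combined class counts. A standalone sketch of that idea (illustrative, not the committed code):

    // Majority-vote prediction: the class with the largest combined count,
    // reported together with its relative frequency.
    object MajorityVoteSketch {
      def predictWithProb(leftCounts: Array[Double],
          rightCounts: Array[Double]): (Int, Double) = {
        val combined = leftCounts.zip(rightCounts).map { case (l, r) => l + r }
        val predict = combined.indices.maxBy(i => combined(i))
        (predict, combined(predict) / combined.sum)
      }

      def main(args: Array[String]): Unit =
        // combined counts are (4, 7, 2), so class 1 wins with prob 7/13
        println(predictWithProb(Array(3.0, 1.0, 0.0), Array(1.0, 6.0, 2.0)))
    }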
@@ -977,9 +988,8 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => - // TODO: Multiclass modification here - val shift = 2 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) binsForNode case Regression => val shift = 3 * node * numBins * numFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 60f43e9278d2..6366960f39b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -31,21 +31,22 @@ object Entropy extends Impurity { /** * :: DeveloperApi :: - * entropy calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return entropy value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - -(f0 * log2(f0)) - (f1 * log2(f1)) + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 0.0 + var classIndex = 0 + while (classIndex < numClasses) { + val freq = counts(classIndex) / totalCount + impurity -= freq * log2(freq) + classIndex += 1 } + impurity } override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c51d76d9b4c5..c8773fc4f860 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -30,21 +30,22 @@ object Gini extends Impurity { /** * :: DeveloperApi :: - * Gini coefficient calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return Gini coefficient value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0 * f0 - f1 * f1 + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 1.0 + var classIndex = 0 + while (classIndex < numClasses) { + val freq = counts(classIndex) / totalCount + impurity -= freq * freq + classIndex += 1 } + impurity } override def calculate(count: Double, sum: Double, sumSquares: Double): Double = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 8eab247cf093..7b2a9320cc21 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -28,13 +28,13 @@ trait Impurity extends Serializable { /** * :: DeveloperApi :: - * information calculation for binary classification - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels * @return information value */ @DeveloperApi - def calculate(c0 : Double, c1 : Double): Double + def calculate(counts: Array[Double], totalCount: Double): Double /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 47d07122af30..555754b1ee03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -25,7 +25,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} */ @Experimental object Variance extends Impurity { - override def calculate(c0: Double, c1: Double): Double = + override def calculate(counts: Array[Double], totalCounts: Double): Double = throw new UnsupportedOperationException("Variance.calculate") /** From e006f9d5914b28b30aa8c24b0d1ff9977f23179e Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 11 May 2014 21:29:47 -0700 Subject: [PATCH 28/72] changing variable names --- .../spark/mllib/tree/DecisionTree.scala | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3da92ed89161..52ae362028f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -46,18 +46,15 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * @return a DecisionTreeModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - - // Converting from standard instance format to weighted input format for tree training - val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = { // Cache input RDD for speedup during multiple passes. - weightedInput.cache() + input.cache() logDebug("algo = " + strategy.algo) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. 
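One caveat with the multiclass Entropy.calculate introduced above: when a class has zero count, freq is 0 and freq * log2(freq) evaluates to 0 * -Infinity = NaN under IEEE floating-point arithmetic. A guarded standalone variant that skips empty classes (an illustration, not the committed code; the limit of x * log2(x) as x goes to 0 is 0):

    import scala.math.log

    object EntropySketch {
      private def log2(x: Double) = log(x) / log(2)

      // Multiclass entropy over per-class counts, skipping empty classes
      // so that 0 * log2(0) never produces NaN.
      def entropy(counts: Array[Double], totalCount: Double): Double = {
        var impurity = 0.0
        var i = 0
        while (i < counts.length) {
          if (counts(i) != 0) {
            val freq = counts(i) / totalCount
            impurity -= freq * log2(freq)
          }
          i += 1
        }
        impurity
      }

      def main(args: Array[String]): Unit =
        println(entropy(Array(5.0, 5.0, 0.0), 10.0)) // 1.0, not NaN
    }

The patch resumes below.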
- val (splits, bins) = DecisionTree.findSplitsBins(weightedInput, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) val numBins = bins(0).length logDebug("numBins = " + numBins) @@ -74,7 +71,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // dummy value for top node (updated during first split calculation) val nodes = new Array[Node](maxNumNodes) // num features - val numFeatures = weightedInput.take(1)(0).features.size + val numFeatures = input.take(1)(0).features.size // Calculate level for single group construction @@ -113,7 +110,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(weightedInput, parentImpurities, + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, level, filters, splits, bins, maxLevelForSingleGroup) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { @@ -216,7 +213,9 @@ object DecisionTree extends Serializable with Logging { * @return a DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } /** @@ -238,7 +237,9 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } @@ -273,7 +274,9 @@ object DecisionTree extends Serializable with Logging { categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) - new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } private val InvalidBinIndex = -1 From 34549d04c2ab36debd461c0e1671dcd4eb8bd270 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 11 May 2014 21:50:31 -0700 Subject: [PATCH 29/72] fixing error during merge --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index fa94cdfa9f27..52ae362028f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -394,9 +394,6 @@ object DecisionTree extends Serializable with Logging { // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - // shift when more than one group is used 
at deep tree level - val groupShift = numNodes * groupIndex - /** Find the filters used before reaching the current code. */ def findParentFilters(nodeIndex: Int): List[Filter] = { if (level == 0) { From e54715199e576e2f6c3f71081c44c142407148d6 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 11 May 2014 22:18:26 -0700 Subject: [PATCH 30/72] minor modifications --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 52ae362028f5..269a7a8c711e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -242,6 +242,9 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } + // TODO: Add multiclass classification support + + // TODO: Add sample weight support /** * Method to train a decision tree model where the instances are represented as an RDD of @@ -723,8 +726,8 @@ object DecisionTree extends Serializable with Logging { val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount) - val leftWeight = leftTotalCount.toDouble / (leftTotalCount + rightTotalCount) - val rightWeight = rightTotalCount.toDouble / (leftTotalCount + rightTotalCount) + val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount) + val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount) val gain = { if (level > 0) { @@ -734,7 +737,7 @@ object DecisionTree extends Serializable with Logging { } } - //TODO: Make modification here + //TODO: Make multiclass modification here val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount) new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) From 75f2bfc379f59063f55c194e9d5d1c07227a31d9 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 12 May 2014 13:06:31 -0700 Subject: [PATCH 31/72] minor code style fix --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 269a7a8c711e..4ccd6a3861be 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -214,7 +214,7 @@ object DecisionTree extends Serializable with Logging { */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { // Converting from standard instance format to weighted input format for tree training - val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } @@ -238,7 +238,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo,impurity,maxDepth) // Converting from standard instance format to weighted input format for tree training - val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + val weightedInput = input.map(x => 
WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } @@ -278,7 +278,7 @@ object DecisionTree extends Serializable with Logging { val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) // Converting from standard instance format to weighted input format for tree training - val weightedInput = input.map(x => WeightedLabeledPoint(x.label,x.features)) + val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } From 6b912dcb9324309b22f759894a544243d41d36b3 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 12 May 2014 13:59:23 -0700 Subject: [PATCH 32/72] added numclasses to tree runner, predict logic for multiclass, add multiclass option to train --- .../examples/mllib/DecisionTreeRunner.scala | 12 ++-- .../spark/mllib/tree/DecisionTree.scala | 65 +++++++++++++++---- .../mllib/tree/configuration/Strategy.scala | 8 +-- .../tree/model/InformationGainStats.scala | 8 ++- 4 files changed, 69 insertions(+), 24 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 9832bec90d7e..7c3cc5ee7579 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -49,6 +49,7 @@ object DecisionTreeRunner { case class Params( input: String = null, algo: Algo = Classification, + numClasses: Int = 2, maxDepth: Int = 5, impurity: ImpurityType = Gini, maxBins: Int = 100) @@ -68,6 +69,9 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numClasses") + .text(s"number of classes for classification, default: ${defaultParams.numClasses}") + .action((x, c) => c.copy(numClasses = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => c.copy(maxBins = x)) @@ -139,12 +143,8 @@ object DecisionTreeRunner { */ private def accuracyScore( model: DecisionTreeModel, - data: RDD[LabeledPoint], - threshold: Double = 0.5): Double = { - def predictedValue(features: Vector): Double = { - if (model.predict(features) < threshold) 0.0 else 1.0 - } - val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() + data: RDD[LabeledPoint]): Double = { + val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() correctCount.toDouble / count } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 4ccd6a3861be..2b8fbdc7a768 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -231,19 +231,43 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth maxDepth maximum depth of the tree * @return a DecisionTreeModel that can be used for prediction */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int): DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth) + // Converting from standard instance format to weighted input format for tree training + 
val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @param numClasses number of classes for classification + * @return a DecisionTreeModel that can be used for prediction + */ def train( input: RDD[LabeledPoint], algo: Algo, impurity: Impurity, - maxDepth: Int): DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth) + maxDepth: Int, + numClasses: Int): DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth,numClasses) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } - // TODO: Add multiclass classification support - // TODO: Add sample weight support /** @@ -258,6 +282,7 @@ object DecisionTree extends Serializable with Logging { * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree + * @param numClasses number of classes for classification * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -272,11 +297,12 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int, + numClasses: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, - categoricalFeaturesInfo) + val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) @@ -737,10 +763,26 @@ object DecisionTree extends Serializable with Logging { } } - //TODO: Make multiclass modification here - val predict = (leftCounts(1) + rightCounts(1)) / (leftTotalCount + rightTotalCount) + val totalCount = leftTotalCount + rightTotalCount - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + // Sum of count for each label + val leftRightCounts: Array[Double] + = leftCounts.zip(rightCounts) + .map{case (leftCount, rightCount) => leftCount + rightCount} + + def indexOfLargest(array: Seq[Double]): Int = { + val result = array.foldLeft(-1,Double.MinValue,0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if(currentValue > maxValue) (currentIndex,currentValue,currentIndex+1) + else 
(maxIndex,maxValue,currentIndex+1) + } + if (result._1 < 0) result._1 else 0 + } + + val predict = indexOfLargest(leftRightCounts) + val prob = leftRightCounts(predict) / totalCount + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) case Regression => val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0) val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1) @@ -793,8 +835,9 @@ object DecisionTree extends Serializable with Logging { /** * Extracts left and right split aggregates. * @param binData Array[Double] of size 2*numFeatures*numSplits - * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], - * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\], + * Array[Array[Array[Double\]\]\]) where each array is of size (numFeatures, + * (numBins - 1), numClasses) */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 3aa2d5382cb8..c397a889f260 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value is 2, + * which leads to binary classification * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and the @@ -37,8 +39,6 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. - * @param numClassesForClassification number of classes for classification. 
Default value is 2 - * leads to binary classification * */ @Experimental @@ -46,11 +46,11 @@ class Strategy ( val algo: Algo, val impurity: Impurity, val maxDepth: Int, + val numClassesForClassification: Int = 2, val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128, - val numClassesForClassification: Int = 2) extends Serializable { + val maxMemoryInMB: Int = 128) extends Serializable { require(numClassesForClassification >= 2) val isMultiClassification = numClassesForClassification > 2 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index cc8a24cce961..fb12298e0f5d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -27,6 +27,7 @@ import org.apache.spark.annotation.DeveloperApi * @param leftImpurity left node impurity * @param rightImpurity right node impurity * @param predict predicted value + * @param prob probability of the label (classification only) */ @DeveloperApi class InformationGainStats( @@ -34,10 +35,11 @@ class InformationGainStats( val impurity: Double, val leftImpurity: Double, val rightImpurity: Double, - val predict: Double) extends Serializable { + val predict: Double, + val prob: Double = 0.0) extends Serializable { override def toString = { - "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" - .format(gain, impurity, leftImpurity, rightImpurity, predict) + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict, prob) } } From 18d283506f311939b4026da63214a559726bc5e4 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 12 May 2014 15:33:10 -0700 Subject: [PATCH 33/72] changing default values for num classes --- .../examples/mllib/DecisionTreeRunner.scala | 17 ++++++++++++----- .../apache/spark/mllib/tree/DecisionTree.scala | 18 ++++++++---------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 7c3cc5ee7579..22c344a7dab9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -49,7 +49,7 @@ object DecisionTreeRunner { case class Params( input: String = null, algo: Algo = Classification, - numClasses: Int = 2, + numClassesForClassification: Int = 2, maxDepth: Int = 5, impurity: ImpurityType = Gini, maxBins: Int = 100) @@ -69,9 +69,10 @@ object DecisionTreeRunner { opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("numClasses") - .text(s"number of classes for classification, default: ${defaultParams.numClasses}") - .action((x, c) => c.copy(numClasses = x)) + opt[Int]("numClassesForClassification") + .text(s"number of classes for classification, " + + s"default: ${defaultParams.numClassesForClassification}") + .action((x, c) => c.copy(numClassesForClassification = x)) opt[Int]("maxBins") .text(s"max number of bins, default: ${defaultParams.maxBins}") .action((x, c) => 
c.copy(maxBins = x)) @@ -122,7 +123,13 @@ object DecisionTreeRunner { case Variance => impurity.Variance } - val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins) + val strategy + = new Strategy( + algo = params.algo, + impurity = impurityCalculator, + maxDepth = params.maxDepth, + maxBins = params.maxBins, + numClassesForClassification = params.numClassesForClassification) val model = DecisionTree.train(training, strategy) if (params.algo == Classification) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 2b8fbdc7a768..86709c293bfc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -236,7 +236,7 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int): DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth) + val strategy = new Strategy(algo, impurity, maxDepth) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) @@ -253,7 +253,7 @@ object DecisionTree extends Serializable with Logging { * @param algo algorithm, classification or regression * @param impurity impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree - * @param numClasses number of classes for classification + * @param numClassesForClassification number of classes for classification. Default value of 2. * @return a DecisionTreeModel that can be used for prediction */ def train( @@ -261,8 +261,8 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int, - numClasses: Int): DecisionTreeModel = { - val strategy = new Strategy(algo,impurity,maxDepth,numClasses) + numClassesForClassification: Int): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) @@ -282,7 +282,7 @@ object DecisionTree extends Serializable with Logging { * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree - * @param numClasses number of classes for classification + * @param numClassesForClassification number of classes for classification. Default value of 2. 
* @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -297,11 +297,11 @@ object DecisionTree extends Serializable with Logging { algo: Algo, impurity: Impurity, maxDepth: Int, - numClasses: Int, + numClassesForClassification: Int, maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins, + val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) @@ -851,10 +851,8 @@ object DecisionTree extends Serializable with Logging { if (strategy.isMultiClassification) { var featureIndex = 0 while (featureIndex < numFeatures){ - val numCategories = strategy.categoricalFeaturesInfo(featureIndex) - val maxSplits = math.pow(2, numCategories) - 1 var splitIndex = 0 - while (splitIndex < maxSplits) { + while (splitIndex < numBins - 1) { var classIndex = 0 while (classIndex < numClasses) { // shift for this featureIndex From d012be77263ac82d71ac2f3cf7e78c756fe4b278 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 12 May 2014 16:46:46 -0700 Subject: [PATCH 34/72] fixed while loop --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 86709c293bfc..1745a4b09e3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -737,6 +737,7 @@ object DecisionTree extends Serializable with Logging { var classIndex = 0 while (classIndex < numClasses) { rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex) + classIndex += 1 } strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) } @@ -1054,7 +1055,7 @@ object DecisionTree extends Serializable with Logging { val binsForNode: Array[Double] = getBinDataForNode(node) logDebug("nodeImpurityIndex = " + nodeImpurityIndex) val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("node impurity = " + parentNodeImpurity) + logDebug("parent node impurity = " + parentNodeImpurity) bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } From ed5a2dffbf3e91215ef220ee8cf4594bb4898a9b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 12 May 2014 20:45:22 -0700 Subject: [PATCH 35/72] fixed classification requirements --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1745a4b09e3d..072651dbf173 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1093,9 +1093,9 @@ object DecisionTree extends Serializable with Logging { */ if (strategy.categoricalFeaturesInfo.size > 0) { val 
maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 - require(numBins >= maxCategoriesForFeatures) + require(numBins > maxCategoriesForFeatures) if (strategy.isMultiClassification) { - require(numBins > math.pow(2, maxCategoriesForFeatures.toInt) - 1) + require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1) } } From d8e4a11833b5a7e5a6e4f0f72d203fbf8e0bb0ed Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 13 May 2014 09:20:40 -0700 Subject: [PATCH 36/72] sample weights --- .../spark/mllib/tree/DecisionTree.scala | 48 +++++++++++++++++-- .../mllib/tree/configuration/Strategy.scala | 7 ++- 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 072651dbf173..c467d5ba65d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -268,7 +268,39 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) } - // TODO: Add sample weight support + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param numClassesForClassification number of classes for classification. Default value of 2. + * @param labelWeights A map storing weights applied to each label for handling unbalanced + * datasets. For example, an entry (n -> k) implies that a weight of k is + * applied to an instance with label n. It's important to note that labels + * are zero-indexed and take values 0, 1, 2, ..., numClasses - 1. + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int, + labelWeights: Map[Int,Int]): DecisionTreeModel = { + val strategy + = new Strategy(algo, impurity, maxDepth, numClassesForClassification, + labelWeights = labelWeights) + // Converting from standard instance format to weighted input format for tree training + val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) + new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) + } /** * Method to train a decision tree model where the instances are represented as an RDD of * (label, features) pairs. The method supports binary classification and regression. For the * binary classification, the label for each instance should either be 0 or 1 to denote the two * classes. * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as * training data * @param algo classification or regression * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree * @param numClassesForClassification number of classes for classification. Default value of 2. + * @param labelWeights A map storing weights applied to each label for handling unbalanced + * datasets. For example, an entry (n -> k) implies that a weight of k is + * applied to an instance with label n. It's important to note that labels + * are zero-indexed and take values 0, 1, 2, ..., numClasses - 1. 
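+ *                     For instance, labelWeights = Map(0 -> 1, 1 -> 5) counts each
+ *                     instance with label 1 five times in the bin aggregates, and any
+ *                     label missing from the map defaults to a weight of 1.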
* @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -298,11 +334,12 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int, numClassesForClassification: Int, + labelWeights: Map[Int,Int], maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, - quantileCalculationStrategy, categoricalFeaturesInfo) + quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) @@ -419,6 +456,9 @@ object DecisionTree extends Serializable with Logging { logDebug("numBins = " + numBins) val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) + val labelWeights = strategy.labelWeights + logDebug("labelWeights = " + labelWeights) + // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -605,7 +645,8 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses label.toInt match { - case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + 1 + case n: Int => + agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1) } featureIndex += 1 } @@ -1010,6 +1051,7 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. var splitIndex = 0 + // TODO: Modify this for categorical variables to go over only valid splits while (splitIndex < numBins - 1) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index c397a889f260..89daaaeccdca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -39,6 +39,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. + * @param labelWeights A map storing weights applied to each label for handling unbalanced + * datasets. For example, an entry (n -> k) implies that a weight of k is + * applied to an instance with label n. It's important to note that labels + * are zero-indexed and take values 0, 1, 2, ..., numClasses - 1. 
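+ *                     A sketch of intended use with this constructor:
+ *                     new Strategy(Classification, Gini, maxDepth = 4,
+ *                     numClassesForClassification = 3, labelWeights = Map(1 -> 5))
+ *                     up-weights label 1 by a factor of 5 in the impurity counts.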
* */ @Experimental @@ -50,7 +54,8 @@ class Strategy ( val maxBins: Int = 100, val quantileCalculationStrategy: QuantileStrategy = Sort, val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val maxMemoryInMB: Int = 128) extends Serializable { + val maxMemoryInMB: Int = 128, + val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable { require(numClassesForClassification >= 2) val isMultiClassification = numClassesForClassification > 2 From ab5cb21df99a10acd15f92413eb43bddbfd15cfd Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 17 May 2014 19:18:55 -0700 Subject: [PATCH 37/72] multiclass logic --- .../spark/mllib/tree/DecisionTree.scala | 94 ++++++++++++++----- .../spark/mllib/tree/impurity/Entropy.scala | 7 +- .../apache/spark/mllib/tree/model/Bin.scala | 2 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 45 ++++----- 4 files changed, 94 insertions(+), 54 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c467d5ba65d9..fb752b06380e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -545,17 +545,24 @@ object DecisionTree extends Serializable with Logging { -1 } + /** + * Sequential search helper method to find bin for categorical feature in multiclass + * classification. Dummy value of 0 used since it is not used in future calculation + */ + def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = 0 + /** * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeature(): Int = { - val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + def sequentialBinSearchForCategoricalFeatureInMultiClassClassification(): Int = { + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 while (binIndex < numCategoricalBins) { val bin = bins(featureIndex)(binIndex) - val category = bin.category + val categories = bin.highSplit.categories val features = labeledPoint.features - if (category == features(featureIndex)) { + if (categories.contains(features(featureIndex))) { return binIndex } binIndex += 1 @@ -572,7 +579,14 @@ object DecisionTree extends Serializable with Logging { binIndex } else { // Perform sequential search to find bin for categorical features. - val binIndex = sequentialBinSearchForCategoricalFeature() + val binIndex = { + if (strategy.isMultiClassification) { + sequentialBinSearchForCategoricalFeatureInBinaryClassification() + } + else { + sequentialBinSearchForCategoricalFeatureInMultiClassClassification() + } + } if (binIndex == -1){ throw new UnknownError("no bin was found for categorical variable.") } @@ -584,7 +598,8 @@ object DecisionTree extends Serializable with Logging { * Finds bins for all nodes (and all features) at a given level. * For l nodes, k features the storage is as follows: * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, - * where b_ij is an integer between 0 and numBins - 1. + * where b_ij is an integer between 0 and numBins - 1 for regressions and binary + * classification and an invalid value for categorical feature in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. 
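 * For example (a sketch): with l = 2 nodes and k = 3 features, a sample falling in
 * bins 4, 0 and 7 at node 0 but filtered out of node 1 is stored as
 * label, 4, 0, 7, -1, followed by two unused slots.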
*/ def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = { @@ -646,7 +661,22 @@ object DecisionTree extends Serializable with Logging { = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses label.toInt match { case n: Int => - agg(aggIndex + n) = agg(aggIndex + n) + 1 * labelWeights.getOrElse(n, 1) + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous && strategy.isMultiClassification) { + // Find all matching bins and increment their values + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){ + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + labelWeights.getOrElse(binIndex, 1) + } + binIndex += 1 + } + } else { + agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) + } } featureIndex += 1 } @@ -705,6 +735,7 @@ object DecisionTree extends Serializable with Logging { agg } + // TODO: Double-check this // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { case Classification => numClasses * numBins * numFeatures * numNodes @@ -785,10 +816,10 @@ object DecisionTree extends Serializable with Logging { } if (leftTotalCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1) } if (rightTotalCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1) } val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount) @@ -812,16 +843,16 @@ object DecisionTree extends Serializable with Logging { = leftCounts.zip(rightCounts) .map{case (leftCount, rightCount) => leftCount + rightCount} - def indexOfLargest(array: Seq[Double]): Int = { + def indexOfLargestArrayElement(array: Array[Double]): Int = { val result = array.foldLeft(-1,Double.MinValue,0) { case ((maxIndex, maxValue, currentIndex), currentValue) => if(currentValue > maxValue) (currentIndex,currentValue,currentIndex+1) else (maxIndex,maxValue,currentIndex+1) } - if (result._1 < 0) result._1 else 0 + if (result._1 < 0) 0 else result._1 } - val predict = indexOfLargest(leftRightCounts) + val predict = indexOfLargestArrayElement(leftRightCounts) val prob = leftRightCounts(predict) / totalCount new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) @@ -1051,8 +1082,20 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures) { // Iterate over all splits. 
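 // Note: the number of valid splits varies by feature type and is computed below; a
 // multiclass categorical feature has 2^(featureCategories - 1) - 1 candidate splits.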
var splitIndex = 0 - // TODO: Modify this for categorical variables to go over only valid splits - while (splitIndex < numBins - 1) { + val maxSplitIndex: Int = { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + numBins - 1 + } else { // Categorical feature + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + if (strategy.isMultiClassification) { + math.pow(2.0, featureCategories - 1).toInt - 1 + } else { // Binary classification + featureCategories + } + } + } + while (splitIndex < maxSplitIndex) { val gainStats = gains(featureIndex)(splitIndex) if (gainStats.gain > bestGainStats.gain) { bestGainStats = gainStats @@ -1176,24 +1219,29 @@ object DecisionTree extends Serializable with Logging { splits(featureIndex)(index) = split } } else { // Categorical feature - val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) // Use different bin/split calculation strategy for multiclass classification if (strategy.isMultiClassification) { - // Iterate from 0 to 2^maxFeatureValue - 1 leading to 2^(maxFeatureValue- 1) - 1 - // combinations. + // 2^(featureCategories - 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { val categories: List[Double] = extractMultiClassCategories(index + 1, featureCategories) splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, categories) bins(featureIndex)(index) = { if (index == 0) { new Bin( new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), Categorical, Double.MinValue) } else { new Bin( splits(featureIndex)(index - 1), splits(featureIndex)(index), Categorical, Double.MinValue) } } @@ -1210,7 +1258,7 @@ ... // Check for missing categorical variables and putting them last in the sorted list. 
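 // (Categories absent from the sampled data receive no centroid above; they are
 // assigned a placeholder value below so that they sort after all observed categories.)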
val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until maxFeatureValue) { + for (i <- 0 until featureCategories) { if (centroidForCategories.contains(i)) { fullCentroidForCategories(i) = centroidForCategories(i) } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 6366960f39b0..ead76d64b638 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -42,8 +42,11 @@ object Entropy extends Impurity { var impurity = 0.0 var classIndex = 0 while (classIndex < numClasses) { - val freq = counts(classIndex) / totalCount - impurity -= freq * log2(freq) + val classCount = counts(classIndex) + if (classCount != 0) { + val freq = classCount / totalCount + impurity -= freq * log2(freq) + } classIndex += 1 } impurity diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index 2d71e1e36606..c89c1e371a40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * @param highSplit signifying the upper threshold for the continuous feature to be * accepted in the bin * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin + * @param category categorical label value accepted in the bin for binary classification */ private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index b5259d8e1882..e7a55d52e736 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(bins.length === 2) @@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -130,6 +131,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -237,7 +239,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) } - test("split and bin calculations for categorical variables wiht multiclass classification") { + test("split and bin calculations for categorical variables with multiclass classification") { val arr = 
DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -245,12 +247,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 100, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), - numClassesForClassification = 3) + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - // Expecting 2^3 - 1 = 7 bins/splits + // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) assert(splits(0)(0).threshold === Double.MinValue) assert(splits(0)(0).featureType === Categorical) @@ -287,6 +289,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(2).categories.contains(1.0)) assert(splits(0)(3) === null) + assert(splits(1)(3) === null) // Check bins. @@ -329,22 +332,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("split and bin calculations for categorical variables with no sample for one category " + - "for multiclass classification") { - val arr = DecisionTreeSuite.generateCategoricalDataPoints() - assert(arr.length === 1000) - val rdd = sc.parallelize(arr) - val strategy = new Strategy( - Classification, - Gini, - maxDepth = 3, - maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3), - numClassesForClassification = 3) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) - - } - test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) @@ -352,6 +339,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, + numClassesForClassification = 2, maxDepth = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -367,8 +355,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.4) - assert(stats.predict < 0.5) + assert(stats.predict === 0) + assert(stats.prob > 0.5) + assert(stats.prob < 0.6) assert(stats.impurity > 0.2) } @@ -403,7 +392,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -426,7 +415,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -450,7 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length 
=== 2) assert(splits(0).length === 99) @@ -474,7 +463,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -498,7 +487,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) From d811425aae6f5b62fb972b23fcbeb97604700c85 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 17 May 2014 23:44:09 -0700 Subject: [PATCH 38/72] multiclass bin aggregate logic --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index fb752b06380e..6539a1194116 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -926,16 +926,22 @@ object DecisionTree extends Serializable with Logging { while (featureIndex < numFeatures){ var splitIndex = 0 while (splitIndex < numBins - 1) { + val totalNodeAgg = Array.ofDim[Double](numClasses) var classIndex = 0 while (classIndex < numClasses) { // shift for this featureIndex val shift = numClasses * featureIndex * numBins - leftNodeAgg(featureIndex)(splitIndex)(classIndex) - = binData(shift + classIndex) - rightNodeAgg(featureIndex)(splitIndex)(classIndex) - = binData(shift + numClasses + classIndex) + val binValue = binData(shift + classIndex) + leftNodeAgg(featureIndex)(splitIndex)(classIndex) = binValue + totalNodeAgg(classIndex) = binValue classIndex += 1 } + // Calculate rightNodeAgg + classIndex = 0 + while (classIndex < numClasses) { + rightNodeAgg(featureIndex)(splitIndex)(classIndex) + = totalNodeAgg(classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) + } splitIndex += 1 } featureIndex += 1 From f16a9bb09d6585fa0963b3753f58cd679a4a2f32 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 17 May 2014 23:45:04 -0700 Subject: [PATCH 39/72] fixing while loop --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 6539a1194116..9ebb1d25ffa0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -941,6 +941,7 @@ object DecisionTree extends Serializable with Logging { while (classIndex < numClasses) { rightNodeAgg(featureIndex)(splitIndex)(classIndex) = totalNodeAgg(classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) + classIndex += 1 } splitIndex += 1 } From 1dd2735d095a46c19a1811c22a65ca211268eedd Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sat, 17 May 
2014 23:50:40 -0700 Subject: [PATCH 40/72] bin search logic for multiclass --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 9ebb1d25ffa0..f1a3aea1f8c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -549,7 +549,9 @@ object DecisionTree extends Serializable with Logging { * Sequential search helper method to find bin for categorical feature in multiclass * classification. Dummy value of 0 used since it is not used in future calculation */ - def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = 0 + def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { + labeledPoint.features(featureIndex).toInt + } /** * Sequential search helper method to find bin for categorical feature. @@ -662,7 +664,7 @@ object DecisionTree extends Serializable with Logging { label.toInt match { case n: Int => val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isFeatureContinuous && strategy.isMultiClassification) { + if (!isFeatureContinuous && strategy.isMultiClassification) { // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 From 7e5f08c7b5938e8f777693a6257e855aff69b2a9 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 18 May 2014 11:17:18 -0700 Subject: [PATCH 41/72] minor doc --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index f1a3aea1f8c6..dd00519e652d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -584,8 +584,7 @@ object DecisionTree extends Serializable with Logging { val binIndex = { if (strategy.isMultiClassification) { sequentialBinSearchForCategoricalFeatureInBinaryClassification() - } - else { + } else { sequentialBinSearchForCategoricalFeatureInMultiClassClassification() } } @@ -601,7 +600,7 @@ object DecisionTree extends Serializable with Logging { * For l nodes, k features the storage is as follows: * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, * where b_ij is an integer between 0 and numBins - 1 for regressions and binary - * classification and an invalid value for categorical feature in multiclass classification. + * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. 
*/ def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = { From bce835fb689f6d06405ec4e2f0cace5056e01492 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 18 May 2014 11:42:39 -0700 Subject: [PATCH 42/72] code cleanup --- .../spark/mllib/tree/DecisionTree.scala | 30 ++++++++++++------- .../mllib/tree/configuration/Strategy.scala | 2 +- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index dd00519e652d..861b35124368 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -144,6 +144,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo new DecisionTreeModel(topNode, strategy.algo) } + // TODO: Unit test this /** * Extract the decision tree node information for the given tree level and node index */ @@ -161,6 +162,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes(nodeIndex) = node } + // TODO: Unit test this /** * Extract the decision tree node information for the children of the node */ @@ -458,6 +460,8 @@ object DecisionTree extends Serializable with Logging { logDebug("numClasses = " + numClasses) val labelWeights = strategy.labelWeights logDebug("labelWeights = " + labelWeights) + val isMulticlassClassification = strategy.isMulticlassClassification + logDebug("isMulticlassClassification = " + isMulticlassClassification) // shift when more than one group is used at deep tree level @@ -582,7 +586,7 @@ object DecisionTree extends Serializable with Logging { } else { // Perform sequential search to find bin for categorical features. val binIndex = { - if (strategy.isMultiClassification) { + if (isMulticlassClassification) { sequentialBinSearchForCategoricalFeatureInBinaryClassification() } else { sequentialBinSearchForCategoricalFeatureInMultiClassClassification() @@ -606,7 +610,9 @@ object DecisionTree extends Serializable with Logging { def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. val arr = new Array[Double](1 + (numFeatures * numNodes)) + // First element of the array is the label of the instance. arr(0) = labeledPoint.label + // Iterate over nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { val parentFilters = findParentFilters(nodeIndex) @@ -629,7 +635,10 @@ object DecisionTree extends Serializable with Logging { arr } - /** + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + /** * Performs a sequential aggregation over a partition for classification. For l nodes, * k features, either the left count or the right count of one of the p bins is * incremented based upon whether the feature is classified as 0 or 1. 
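 * As a sketch of the flat layout used here: for node i, feature j, bin b and label c,
 * the updated slot is numClasses * numBins * numFeatures * i + numClasses * numBins * j
 * + numClasses * b + c.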
@@ -663,7 +672,7 @@ object DecisionTree extends Serializable with Logging { label.toInt match { case n: Int => val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (!isFeatureContinuous && strategy.isMultiClassification) { + if (!isFeatureContinuous && isMulticlassClassification) { // Find all matching bins and increment their values val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 @@ -736,7 +745,6 @@ object DecisionTree extends Serializable with Logging { agg } - // TODO: Double-check this // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { case Classification => numClasses * numBins * numFeatures * numNodes @@ -760,9 +768,6 @@ object DecisionTree extends Serializable with Logging { combinedAggregate } - // Find feature bins for all nodes at a level. - val binMappedRDD = input.map(x => findBinsForLevel(x)) - // Calculate bin aggregates. val binAggregates = { binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) @@ -922,7 +927,7 @@ object DecisionTree extends Serializable with Logging { val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - if (strategy.isMultiClassification) { + if (isMulticlassClassification) { var featureIndex = 0 while (featureIndex < numFeatures){ var splitIndex = 0 @@ -1096,7 +1101,7 @@ object DecisionTree extends Serializable with Logging { numBins - 1 } else { // Categorical feature val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - if (strategy.isMultiClassification) { + if (isMulticlassClassification) { math.pow(2.0, featureCategories - 1).toInt - 1 } else { // Binary classification featureCategories @@ -1177,6 +1182,9 @@ object DecisionTree extends Serializable with Logging { val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt logDebug("numBins = " + numBins) + val isMulticlassClassification = strategy.isMulticlassClassification + logDebug("isMulticlassClassification = " + isMulticlassClassification) + /* * Ensure #bins is always greater than the categories. 
For multiclass classification, @@ -1187,7 +1195,7 @@ object DecisionTree extends Serializable with Logging { if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures) - if (strategy.isMultiClassification) { + if (isMulticlassClassification) { require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1) } } @@ -1230,7 +1238,7 @@ object DecisionTree extends Serializable with Logging { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) // Use different bin/split calculation strategy for multiclass classification - if (strategy.isMultiClassification) { + if (isMulticlassClassification) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 89daaaeccdca..e51f0f726183 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -58,6 +58,6 @@ class Strategy ( val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable { require(numClassesForClassification >= 2) - val isMultiClassification = numClassesForClassification > 2 + val isMulticlassClassification = numClassesForClassification > 2 } From 828ff169778e03b2a19ba85ffbd6619c8554ae25 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 20 May 2014 17:27:30 -0700 Subject: [PATCH 43/72] added categorical variable test --- .../spark/mllib/tree/DecisionTreeSuite.scala | 49 ++++++++++++++++++- 1 file changed, 47 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index e7a55d52e736..664abf742d4a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -133,7 +133,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 3, numClassesForClassification = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) // Check splits. 
@@ -483,7 +483,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("test second level node building with/without groups") { + test("second level node building with/without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -529,6 +529,33 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } + test("stump with continuous variables for multiclass classification") { + assert(true==true) + } + + test("stump with categorical variables for multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(0)) + assert(bestSplit.featureType === Categorical) + println(bestSplit) + } + + test("stump with continuous + categorical variables for multiclass classification") { + assert(true==true) + } + } object DecisionTreeSuite { @@ -576,4 +603,22 @@ object DecisionTreeSuite { } arr } + + def generateCategoricalDataPointsForMulticlass(): Array[WeightedLabeledPoint] = { + val arr = new Array[WeightedLabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } + } + println(arr(0)) + println(arr(1000)) + println(arr(2000)) + arr + } + } From 8cfd3b6405891334a89f834c8c1fedbb3eb0868a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 21 May 2014 18:15:19 -0700 Subject: [PATCH 44/72] working for categorical multiclass classification --- .../spark/mllib/tree/DecisionTree.scala | 364 +++++++++++------- .../spark/mllib/tree/DecisionTreeSuite.scala | 43 ++- 2 files changed, 266 insertions(+), 141 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 861b35124368..c6a306d43633 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -647,9 +647,9 @@ object DecisionTree extends Serializable with Logging { * numClasses * numSplits * numFeatures*numNodes for classification * @param arr Array[Double] of size 1 + (numFeatures * numNodes) * @return Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures * numNodes for classification + * 2 * numSplits * numFeatures * numNodes for classification */ - def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { // Iterate over all nodes. 
var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -666,27 +666,11 @@ object DecisionTree extends Serializable with Logging { val arrShift = 1 + numFeatures * nodeIndex val arrIndex = arrShift + featureIndex // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggShift = 2 * numBins * numFeatures * nodeIndex val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 label.toInt match { - case n: Int => - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (!isFeatureContinuous && isMulticlassClassification) { - // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)){ - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + labelWeights.getOrElse(binIndex, 1) - } - binIndex += 1 - } - } else { - agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) - } + case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) } featureIndex += 1 } @@ -695,6 +679,77 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numClasses * numSplits * numFeatures * numNodes for classification + */ + def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + val rightChildShift = numClasses * numBins * numFeatures * numNodes + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isContinuousFeature) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + label.toInt match { + case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) + } + } else { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. 
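+ // (Continuous features fill only this left-child region; their right-child totals
+ // are recovered later by running sums over the ordered bins. Categorical bins are
+ // category subsets, so their complements are tracked explicitly from rightChildShift.)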
+ val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + label.toInt match { + case n: Int => + // Find all matching bins and increment their values + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)) { + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + labelWeights.getOrElse(n, 1) + } else { + agg(rightChildShift + aggIndex + binIndex) + = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(n, 1) + + } + binIndex += 1 + } + } + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + /** * Performs a sequential aggregation over a partition for regression. For l nodes, k features, * the count, sum, sum of squares of one of the p bins is incremented. @@ -739,7 +794,12 @@ object DecisionTree extends Serializable with Logging { */ def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { - case Classification => classificationBinSeqOp(arr, agg) + case Classification => + if(isMulticlassClassification) { + multiClassificationBinSeqOp(arr, agg) + } else { + binaryClassificationBinSeqOp(arr, agg) + } case Regression => regressionBinSeqOp(arr, agg) } agg @@ -747,7 +807,12 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregate length for classification or regression. val binAggregateLength = strategy.algo match { - case Classification => numClasses * numBins * numFeatures * numNodes + case Classification => + if (isMulticlassClassification){ + 2 * numClasses * numBins * numFeatures * numNodes + } else { + 2 * numBins * numFeatures * numNodes + } case Regression => 3 * numBins * numFeatures * numNodes } logDebug("binAggregateLength = " + binAggregateLength) @@ -920,80 +985,139 @@ object DecisionTree extends Serializable with Logging { */ def extractLeftRightNodeAggregates( binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { + + + def findAggForOrderedFeature( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. 
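+ // (Running sums: the left aggregate grows from the lowest bin upward and the right
+ // aggregate from the highest bin downward, so each candidate split obtains left and
+ // right child totals in a single pass over the bins.)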
+ var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + + 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = + binData(shift + (2 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = + binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + + splitIndex += 1 + } + } + + def extractAggForCategoricalFeature( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + val rightChildShift = numClasses * numBins * numFeatures + var splitIndex = 0 + while (splitIndex < numBins - 1) { + var classIndex = 0 + while (classIndex < numClasses) { + // shift for this featureIndex + val shift = numClasses * featureIndex * numBins + splitIndex * numClasses + val leftBinValue = binData(shift + classIndex) + val rightBinValue = binData(rightChildShift + shift + classIndex) + leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue + rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue + classIndex += 1 + } + splitIndex += 1 + } + } + + def findAggForRegression( + leftNodeAgg: Array[Array[Array[Double]]], + rightNodeAgg: Array[Array[Array[Double]]], + featureIndex: Int) { + + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(numBins - 2)(0) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(numBins - 2)(1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(numBins - 2)(2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. 
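+      // Same running-sum construction as in the classification case, but each
+      // bin carries a (count, sum, sum of squares) triple, which is sufficient
+      // to derive the variance of any candidate left/right partition. With
+      // assumed values count = 10, sum = 25 and sumSquares = 80, the variance
+      // is 80 / 10 - (25 / 10)^2 = 1.75.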
+ var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(0) + leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(splitIndex - 1)(1) + leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(splitIndex - 1)(2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2) + + splitIndex += 1 + } + } + strategy.algo match { case Classification => - // Initialize left and right split aggregates. val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - - if (isMulticlassClassification) { - var featureIndex = 0 - while (featureIndex < numFeatures){ - var splitIndex = 0 - while (splitIndex < numBins - 1) { - val totalNodeAgg = Array.ofDim[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins - val binValue = binData(shift + classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) = binValue - totalNodeAgg(classIndex) = binValue - classIndex += 1 - } - // Calculate rightNodeAgg - classIndex = 0 - while (classIndex < numClasses) { - rightNodeAgg(featureIndex)(splitIndex)(classIndex) - = totalNodeAgg(classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) - classIndex += 1 - } - splitIndex += 1 - } - featureIndex += 1 - } - } else { - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 2 * featureIndex * numBins - - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) - = binData(shift + (2 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) - = binData(shift + (2 * (numBins - 1)) + 1) - - // Iterate over all splits. 
- var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(0) - leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + - 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = - binData(shift + (2 *(numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = - binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) - - splitIndex += 1 + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (isMulticlassClassification){ + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + extractAggForCategoricalFeature(leftNodeAgg, rightNodeAgg, featureIndex) } - featureIndex += 1 + } else { + findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) } + featureIndex += 1 } + (leftNodeAgg, rightNodeAgg) case Regression => // Initialize left and right split aggregates. @@ -1002,47 +1126,7 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(numBins - 2)(2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. 
- var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(0) - leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) + - leftNodeAgg(featureIndex)(splitIndex - 1)(1) - leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) + - leftNodeAgg(featureIndex)(splitIndex - 1)(2) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = - binData(shift + (3 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = - binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) = - binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2) - - splitIndex += 1 - } + findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) featureIndex += 1 } (leftNodeAgg, rightNodeAgg) @@ -1134,9 +1218,23 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + if (isMulticlassClassification) { + val shift = numClasses * node * numBins * numFeatures + val rightChildShift = numClasses * numBins * numFeatures * numNodes + val binsForNode = { + val leftChildData + = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + val rightChildData + = binAggregates.slice(rightChildShift + shift, + rightChildShift + shift + numClasses * numBins * numFeatures) + leftChildData ++ rightChildData + } + binsForNode + } else { + val shift = numClasses * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) + binsForNode + } case Regression => val shift = 3 * node * numBins * numFeatures val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 664abf742d4a..41cf5a120bac 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -529,10 +529,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("stump with continuous variables for multiclass classification") { - assert(true==true) - } - test("stump with categorical variables for multiclass classification") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val input = sc.parallelize(arr) @@ -547,11 +543,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val bestSplit = bestSplits(0)._1 assert(bestSplit.feature === 0) assert(bestSplit.categories.length === 1) - assert(bestSplit.categories.contains(0)) + 
assert(bestSplit.categories.contains(1))
     assert(bestSplit.featureType === Categorical)
+  }
+
+
+  test("stump with continuous variables for multiclass classification") {
+    val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3)
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+
+    //assert(bestSplit.feature === 1)
+    //assert(bestSplit.featureType == Continuous)
+    //assert(bestSplit.threshold > 1000)
    println(bestSplit)
+  }
+
   test("stump with continuous + categorical variables for multiclass classification") {
     assert(true==true)
   }
@@ -615,10 +632,20 @@ object DecisionTreeSuite {
         arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
       }
     }
-    println(arr(0))
-    println(arr(1000))
-    println(arr(2000))
     arr
   }

+  def generateContinuousDataPointsForMulticlass(): Array[WeightedLabeledPoint] = {
+    val arr = new Array[WeightedLabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 2000) {
+        arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 100))
+      } else {
+        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 3000))
+      }
+    }
+    arr
+  }
+
+
 }

From f5f6b833d62d7fba982c62971dc373c70363385e Mon Sep 17 00:00:00 2001
From: Manish Amde 
Date: Wed, 21 May 2014 18:40:24 -0700
Subject: [PATCH 45/72] multiclass for continuous variables

---
 .../spark/mllib/tree/DecisionTree.scala       | 56 +++++++++----------
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 12 ++--
 2 files changed, 32 insertions(+), 36 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c6a306d43633..82fd719c7599 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -987,48 +987,44 @@ object DecisionTree extends Serializable with Logging {
     binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {


-    def findAggForOrderedFeature(
+    def findAggForOrderedFeatureClassification(
         leftNodeAgg: Array[Array[Array[Double]]],
         rightNodeAgg: Array[Array[Array[Double]]],
         featureIndex: Int) {

       // shift for this featureIndex
-      val shift = 2 * featureIndex * numBins
-
-      // left node aggregate for the lowest split
-      leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
-      leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
-
-      // right node aggregate for the highest split
-      rightNodeAgg(featureIndex)(numBins - 2)(0)
-        = binData(shift + (2 * (numBins - 1)))
-      rightNodeAgg(featureIndex)(numBins - 2)(1)
-        = binData(shift + (2 * (numBins - 1)) + 1)
+      val shift = numClasses * featureIndex * numBins
+
+      var classIndex = 0
+      while (classIndex < numClasses) {
+        // left node aggregate for the lowest split
+        leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
+        // right node aggregate for the highest split
+        rightNodeAgg(featureIndex)(numBins - 2)(classIndex)
+          = binData(shift + (numClasses * (numBins - 1)) + classIndex)
+        classIndex += 1
+      }

       // Iterate over all splits.
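       // The rewritten loop below generalizes the two-counter binary update to
       // numClasses running sums per bin. A sketch with an assumed numClasses = 3:
       // binData stores [c0, c1, c2] per bin, and the left aggregate for class k
       // at a split is the prefix sum of binData(shift + 3 * b + k) over all bins
       // b at or below that split, with the right aggregate built symmetrically
       // from the highest bin downwards.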
var splitIndex = 1 while (splitIndex < numBins - 1) { // calculating left node aggregate for a split as a sum of left node aggregate of a // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 2 * splitIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(0) - leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 2 * splitIndex + - 1) + leftNodeAgg(featureIndex)(splitIndex - 1)(1) - - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = - binData(shift + (2 * (numBins - 2 - splitIndex))) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = - binData(shift + (2 * (numBins - 2 - splitIndex) + 1)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) - + var innerClassIndex = 0 + while (innerClassIndex < numClasses) { + leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) + = binData(shift + numClasses * splitIndex + innerClassIndex) + + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) + rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = + binData(shift + (numClasses * (numBins - 2 - splitIndex) + innerClassIndex)) + + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) + innerClassIndex += 1 + } splitIndex += 1 } } - def extractAggForCategoricalFeature( + def findAggregateForCategoricalFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], featureIndex: Int) { @@ -1108,12 +1104,12 @@ object DecisionTree extends Serializable with Logging { if (isMulticlassClassification){ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { - findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } else { - extractAggForCategoricalFeature(leftNodeAgg, rightNodeAgg, featureIndex) + findAggregateForCategoricalFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } } else { - findAggForOrderedFeature(leftNodeAgg, rightNodeAgg, featureIndex) + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } featureIndex += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 41cf5a120bac..12477c1fc1b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -561,10 +561,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits.length === 1) val bestSplit = bestSplits(0)._1 - //assert(bestSplit.feature === 1) - //assert(bestSplit.featureType == Continuous) - //assert(bestSplit.threshold > 1000) - println(bestSplit) + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) } @@ -639,9 +639,9 @@ object DecisionTreeSuite { val arr = new Array[WeightedLabeledPoint](3000) for (i <- 0 until 3000) { if (i < 2000) { - arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 100)) + arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, i)) } else { - 
arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 3000))
+        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, i))
       }
     }
     arr

From 1892a2cfc1b200c828e3e3efbe3e888988172dcc Mon Sep 17 00:00:00 2001
From: Manish Amde 
Date: Fri, 23 May 2014 10:27:01 -0700
Subject: [PATCH 46/72] tests and use multiclass bin aggregate length when at least one categorical feature is present

---
 docs/mllib-decision-tree.md                   |   6 +-
 .../spark/mllib/tree/DecisionTree.scala       | 234 +++++++++---------
 .../mllib/tree/configuration/Strategy.scala   |   2 +
 .../spark/mllib/tree/DecisionTreeSuite.scala  |  17 +-
 4 files changed, 139 insertions(+), 120 deletions(-)

diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index acf0feff42a8..16e5d10ecbe7 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -76,8 +76,8 @@ bins if the condition is not satisfied.
 
 **Categorical features**
 
-For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
-binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
+For `$M$` categorical features, one could come up with `$2^{M-1}-1$` split candidates. For
+binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
 categorical feature values by the proportion of labels falling in one of the two classes (see
 Section 9.2.4 in
 [Elements of Statistical Machine Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
 details). For example, for a binary classification problem with one categorical
 categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
 features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B
 and A , B \| C where \| denotes the split.
- 
+ 
 ### Stopping rule
 
 The recursive tree construction is stopped at a node when one of the two conditions is met:
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 82fd719c7599..12eb09ffc39c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -78,11 +78,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // Max memory usage for aggregates
     val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    val numElementsPerNode =
-      strategy.algo match {
-        case Classification => 2 * numBins * numFeatures
-        case Regression => 3 * numBins * numFeatures
-      }
+    val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
+      strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
+      strategy.algo)
     logDebug("numElementsPerNode = " + numElementsPerNode)
     val arraySizePerNode = 8 * numElementsPerNode // approx.
memory usage for bin aggregate array @@ -144,7 +142,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo new DecisionTreeModel(topNode, strategy.algo) } - // TODO: Unit test this /** * Extract the decision tree node information for the given tree level and node index */ @@ -162,7 +159,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo nodes(nodeIndex) = node } - // TODO: Unit test this /** * Extract the decision tree node information for the children of the node */ @@ -290,12 +286,12 @@ object DecisionTree extends Serializable with Logging { * @return a DecisionTreeModel that can be used for prediction */ def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int, - numClassesForClassification: Int, - labelWeights: Map[Int,Int]): DecisionTreeModel = { + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClassesForClassification: Int, + labelWeights: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, labelWeights = labelWeights) @@ -462,7 +458,9 @@ object DecisionTree extends Serializable with Logging { logDebug("labelWeights = " + labelWeights) val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) - + val isMulticlassClassificationWithCategoricalFeatures + = strategy.isMulticlassWithCategoricalFeatures + logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassClassificationWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex @@ -518,9 +516,7 @@ object DecisionTree extends Serializable with Logging { /** * Find bin for one feature. */ - def findBin( - featureIndex: Int, - labeledPoint: WeightedLabeledPoint, + def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint, isFeatureContinuous: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -636,9 +632,48 @@ object DecisionTree extends Serializable with Logging { } // Find feature bins for all nodes at a level. - val binMappedRDD = input.map(x => findBinsForLevel(x)) + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int, + label: Double, featureIndex: Int) = { + + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + val labelInt = label.toInt + agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1) + } - /** + def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], + label: Double, agg: Array[Double], rightChildShift: Int) = { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. 
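+      // For unordered (categorical, multiclass) features a single sample can sit
+      // on the left side of several candidate splits at once, so every matching
+      // bin is incremented rather than exactly one. A sketch with an assumed
+      // featureCategories = 4: numCategoricalBins = 2^3 - 1 = 7 candidate splits
+      // are scanned, with right-child counts kept in a second half of the
+      // aggregate array at offset rightChildShift.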
+ val aggShift = numClasses * numBins * numFeatures * nodeIndex + val aggIndex + = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses + // Find all matching bins and increment their values + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val labelInt = label.toInt + if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { + agg(aggIndex + binIndex) + = agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1) + } else { + agg(rightChildShift + aggIndex + binIndex) + = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1) + } + binIndex += 1 + } + } + + /** * Performs a sequential aggregation over a partition for classification. For l nodes, * k features, either the left count or the right count of one of the p bins is * incremented based upon whether the feature is classified as 0 or 1. @@ -649,7 +684,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 2 * numSplits * numFeatures * numNodes for classification */ - def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -662,16 +697,7 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = 2 * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 - label.toInt match { - case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) - } + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) featureIndex += 1 } } @@ -679,76 +705,43 @@ object DecisionTree extends Serializable with Logging { } } - /** - * Performs a sequential aggregation over a partition for classification. For l nodes, - * k features, either the left count or the right count of one of the p bins is - * incremented based upon whether the feature is classified as 0 or 1. - * - * @param agg Array[Double] storing aggregate calculation of size - * numClasses * numSplits * numFeatures*numNodes for classification - * @param arr Array[Double] of size 1 + (numFeatures * numNodes) - * @return Array[Double] storing aggregate calculation of size - * 2 * numClasses * numSplits * numFeatures * numNodes for classification - */ - def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { - // Iterate over all nodes. - var nodeIndex = 0 - while (nodeIndex < numNodes) { - // Check whether the instance was valid for this nodeIndex. - val validSignalIndex = 1 + numFeatures * nodeIndex - val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex - if (isSampleValidForNode) { - val rightChildShift = numClasses * numBins * numFeatures * numNodes - // actual class label - val label = arr(0) - // Iterate over all features. 
- var featureIndex = 0 - while (featureIndex < numFeatures) { - val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isContinuousFeature) { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - label.toInt match { - case n: Int => agg(aggIndex + n) = agg(aggIndex + n) + labelWeights.getOrElse(n, 1) - } - } else { - // Find the bin index for this feature. - val arrShift = 1 + numFeatures * nodeIndex - val arrIndex = arrShift + featureIndex - // Update the left or right count for one bin. - val aggShift = numClasses * numBins * numFeatures * nodeIndex - val aggIndex - = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses - label.toInt match { - case n: Int => - // Find all matching bins and increment their values - val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - if (bins(featureIndex)(binIndex).highSplit.categories.contains(n)) { - agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + labelWeights.getOrElse(n, 1) - } else { - agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(n, 1) - - } - binIndex += 1 - } - } - } - featureIndex += 1 - } + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * numClasses * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numClasses * numSplits * numFeatures * numNodes for classification + */ + def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + val rightChildShift = numClasses * numBins * numFeatures * numNodes + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isContinuousFeature) { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) + } else { + updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift) + } + featureIndex += 1 } - nodeIndex += 1 } + nodeIndex += 1 } + } /** * Performs a sequential aggregation over a partition for regression. 
For l nodes, k features, @@ -760,7 +753,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 3 * numSplits * numFeatures * numNodes for regression */ - def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -795,7 +788,7 @@ object DecisionTree extends Serializable with Logging { def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { strategy.algo match { case Classification => - if(isMulticlassClassification) { + if(isMulticlassClassificationWithCategoricalFeatures) { multiClassificationBinSeqOp(arr, agg) } else { binaryClassificationBinSeqOp(arr, agg) @@ -806,15 +799,8 @@ object DecisionTree extends Serializable with Logging { } // Calculate bin aggregate length for classification or regression. - val binAggregateLength = strategy.algo match { - case Classification => - if (isMulticlassClassification){ - 2 * numClasses * numBins * numFeatures * numNodes - } else { - 2 * numBins * numFeatures * numNodes - } - case Regression => 3 * numBins * numFeatures * numNodes - } + val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses, + isMulticlassClassificationWithCategoricalFeatures, strategy.algo) logDebug("binAggregateLength = " + binAggregateLength) /** @@ -1024,7 +1010,7 @@ object DecisionTree extends Serializable with Logging { } } - def findAggregateForCategoricalFeatureClassification( + def findAggForUnorderedFeatureClassification( leftNodeAgg: Array[Array[Array[Double]]], rightNodeAgg: Array[Array[Array[Double]]], featureIndex: Int) { @@ -1101,12 +1087,12 @@ object DecisionTree extends Serializable with Logging { val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) var featureIndex = 0 while (featureIndex < numFeatures) { - if (isMulticlassClassification){ + if (isMulticlassClassificationWithCategoricalFeatures){ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } else { - findAggregateForCategoricalFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } } else { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -1214,7 +1200,7 @@ object DecisionTree extends Serializable with Logging { def getBinDataForNode(node: Int): Array[Double] = { strategy.algo match { case Classification => - if (isMulticlassClassification) { + if (isMulticlassClassificationWithCategoricalFeatures) { val shift = numClasses * node * numBins * numFeatures val rightChildShift = numClasses * numBins * numFeatures * numNodes val binsForNode = { @@ -1251,10 +1237,22 @@ object DecisionTree extends Serializable with Logging { bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) node += 1 } - bestSplits } + private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int, + isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = { + algo match { + case Classification => + if (isMulticlassClassificationWithCategoricalFeatures) { + 2 * numClasses * numBins * numFeatures + } else { + numClasses * numBins * numFeatures + } + case Regression => 3 * numBins * numFeatures + } + } + /** * Returns split and bins for 
decision tree calculation. * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data @@ -1288,9 +1286,12 @@ object DecisionTree extends Serializable with Logging { */ if (strategy.categoricalFeaturesInfo.size > 0) { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 - require(numBins > maxCategoriesForFeatures) + require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + + "in categorical features") if (isMulticlassClassification) { - require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1) + require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1, + "numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" + + " with categorical variables") } } @@ -1331,7 +1332,8 @@ object DecisionTree extends Serializable with Logging { } else { // Categorical feature val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - // Use different bin/split calculation strategy for multiclass classification + // Use different bin/split calculation strategy for categorical features in multiclass + // classification if (isMulticlassClassification) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index e51f0f726183..7aec14d293ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -59,5 +59,7 @@ class Strategy ( require(numClassesForClassification >= 2) val isMulticlassClassification = numClassesForClassification > 2 + val isMulticlassWithCategoricalFeatures + = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 12477c1fc1b0..c06ad055afee 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -570,7 +570,22 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { test("stump with continuous + categorical variables for multiclass classification") { - assert(true==true) + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) } } From 12e6d0ab757768f25581bb6a391c2b80791ea30b Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 25 May 2014 17:59:46 -0700 Subject: [PATCH 47/72] minor: removing line in doc --- docs/mllib-decision-tree.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index 
6089023fe67f..8f8f54a3d92d 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -86,7 +86,7 @@ details). For example, for a binary classification problem with one categorical categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical features are ordered as A followed by C followed B or A, B, C. The two split candidates are A \| C, B and A , B \| C where \| denotes the split. - + ### Stopping rule The recursive tree construction is stopped at a node when one of the two conditions is met: From 237762d3186c2f271e26a9a8bb61899016290312 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Sun, 25 May 2014 19:30:47 -0700 Subject: [PATCH 48/72] renaming functions --- .../spark/mllib/tree/DecisionTree.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 12eb09ffc39c..61975fa69c68 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -547,16 +547,18 @@ object DecisionTree extends Serializable with Logging { /** * Sequential search helper method to find bin for categorical feature in multiclass - * classification. Dummy value of 0 used since it is not used in future calculation + * classification. The category is returned since each category can belong to multiple + * splits. The actual left/right child allocation per split is performed in the + * sequential phase of the bin aggregate operation. */ - def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { + def sequentialBinSearchForCategoricalFeatureInMulticlassClassification(): Int = { labeledPoint.features(featureIndex).toInt } /** * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeatureInMultiClassClassification(): Int = { + def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 @@ -583,9 +585,9 @@ object DecisionTree extends Serializable with Logging { // Perform sequential search to find bin for categorical features. val binIndex = { if (isMulticlassClassification) { - sequentialBinSearchForCategoricalFeatureInBinaryClassification() + sequentialBinSearchForCategoricalFeatureInMulticlassClassification() } else { - sequentialBinSearchForCategoricalFeatureInMultiClassClassification() + sequentialBinSearchForCategoricalFeatureInBinaryClassification() } } if (binIndex == -1){ @@ -684,7 +686,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 2 * numSplits * numFeatures * numNodes for classification */ - def binaryClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. 
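       // "Ordered" covers continuous features and categorical features whose bins
       // admit a linear ordering (binary classification and regression), where a
       // single bin count is updated per sample. For example, with an assumed
       // numClassesForClassification = 3 and categoricalFeaturesInfo = Map(0 -> 4),
       // binSeqOp would route to the unordered variant below instead.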
var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -716,7 +718,7 @@ object DecisionTree extends Serializable with Logging { * @return Array[Double] storing aggregate calculation of size * 2 * numClasses * numSplits * numFeatures * numNodes for classification */ - def multiClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { + def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = { // Iterate over all nodes. var nodeIndex = 0 while (nodeIndex < numNodes) { @@ -789,9 +791,9 @@ object DecisionTree extends Serializable with Logging { strategy.algo match { case Classification => if(isMulticlassClassificationWithCategoricalFeatures) { - multiClassificationBinSeqOp(arr, agg) + unorderedClassificationBinSeqOp(arr, agg) } else { - binaryClassificationBinSeqOp(arr, agg) + orderedClassificationBinSeqOp(arr, agg) } case Regression => regressionBinSeqOp(arr, agg) } From 34ee7b9bc4bbf1859c946456242bb01bbd2c0e09 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 26 May 2014 14:59:09 -0700 Subject: [PATCH 49/72] minor: code style --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 61975fa69c68..859f4f72957c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -230,10 +230,10 @@ object DecisionTree extends Serializable with Logging { * @return a DecisionTreeModel that can be used for prediction */ def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int): DecisionTreeModel = { + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) @@ -279,10 +279,10 @@ object DecisionTree extends Serializable with Logging { * @param impurity impurity criterion used for information gain calculation * @param maxDepth maxDepth maximum depth of the tree * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param labelWeights A map storing weights applied to each label for handling unbalanced + * @param labelWeights A map storing weights for each label to handle unbalanced * datasets. For example, an entry (n -> k) implies the a weight of k is * applied to an instance with label n. It's important to note that labels - * are zero-index and take values 0, 1, 2, ... , numClasses. + * are zero-index and take values 0, 1, 2, ... , numClasses - 1. * @return a DecisionTreeModel that can be used for prediction */ def train( @@ -316,7 +316,7 @@ object DecisionTree extends Serializable with Logging { * @param labelWeights A map storing weights applied to each label for handling unbalanced * datasets. For example, an entry (n -> k) implies the a weight of k is * applied to an instance with label n. It's important to note that labels - * are zero-index and take values 0, 1, 2, ... , numClasses. + * are zero-index and take values 0, 1, 2, ... , numClasses - 1. 
* @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and From 23d42684e02f7bb60cf55bae9e870dc5fde331c9 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 26 May 2014 15:01:06 -0700 Subject: [PATCH 50/72] minor: another minor code style --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 859f4f72957c..8e7a6917946b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -292,8 +292,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, numClassesForClassification: Int, labelWeights: Map[Int,Int]): DecisionTreeModel = { - val strategy - = new Strategy(algo, impurity, maxDepth, numClassesForClassification, + val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, labelWeights = labelWeights) // Converting from standard instance format to weighted input format for tree training val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) From e3e88433a5b106d4a4b08e97393462c55764f554 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 26 May 2014 18:43:14 -0700 Subject: [PATCH 51/72] minor code formatting --- .../spark/examples/mllib/DecisionTreeRunner.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 22c344a7dab9..825a1de7291a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -125,11 +125,11 @@ object DecisionTreeRunner { val strategy = new Strategy( - algo = params.algo, - impurity = impurityCalculator, - maxDepth = params.maxDepth, - maxBins = params.maxBins, - numClassesForClassification = params.numClassesForClassification) + algo = params.algo, + impurity = impurityCalculator, + maxDepth = params.maxDepth, + maxBins = params.maxBins, + numClassesForClassification = params.numClassesForClassification) val model = DecisionTree.train(training, strategy) if (params.algo == Classification) { From adc7315c81aac6df589be8578700ef0a4d691e56 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Wed, 4 Jun 2014 11:01:25 -0700 Subject: [PATCH 52/72] support ordered categorical splits for multiclass classification --- .../spark/mllib/tree/DecisionTree.scala | 95 ++++++++++++------ .../spark/mllib/tree/DecisionTreeSuite.scala | 96 ++++++++++++++++++- 2 files changed, 160 insertions(+), 31 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 8e7a6917946b..f5054dbf0d76 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -516,7 +516,7 @@ object DecisionTree extends Serializable with Logging { * Find bin for one feature. 
*/ def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint, - isFeatureContinuous: Boolean): Int = { + isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -550,14 +550,14 @@ object DecisionTree extends Serializable with Logging { * splits. The actual left/right child allocation per split is performed in the * sequential phase of the bin aggregate operation. */ - def sequentialBinSearchForCategoricalFeatureInMulticlassClassification(): Int = { + def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { labeledPoint.features(featureIndex).toInt } /** * Sequential search helper method to find bin for categorical feature. */ - def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = { + def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = { val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1 var binIndex = 0 @@ -583,10 +583,10 @@ object DecisionTree extends Serializable with Logging { } else { // Perform sequential search to find bin for categorical features. val binIndex = { - if (isMulticlassClassification) { - sequentialBinSearchForCategoricalFeatureInMulticlassClassification() + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { + sequentialBinSearchForUnorderedCategoricalFeatureInClassification() } else { - sequentialBinSearchForCategoricalFeatureInBinaryClassification() + sequentialBinSearchForOrderedCategoricalFeatureInClassification() } } if (binIndex == -1){ @@ -622,8 +622,19 @@ object DecisionTree extends Serializable with Logging { } else { var featureIndex = 0 while (featureIndex < numFeatures) { - val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex) + val isFeatureContinuous = featureInfo.isEmpty + if (isFeatureContinuous) { + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, false) + } else { + val featureCategories = featureInfo.get + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + arr(shift + featureIndex) + = findBin(featureIndex, labeledPoint, isFeatureContinuous, + isSpaceSufficientForAllCategoricalSplits) + } featureIndex += 1 } } @@ -731,12 +742,19 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. 
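           // The isSpaceSufficientForAllCategoricalSplits check below gates the
           // unordered treatment on whether all 2^(K-1) - 1 category splits fit
           // in the available bins. With assumed values: K = 10 categories would
           // need 2^9 - 1 = 511 bins, so with maxBins = 100 the feature falls
           // back to the ordered path, while K = 3 needs only 3 bins and stays
           // unordered.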
var featureIndex = 0 while (featureIndex < numFeatures) { - val isContinuousFeature = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty - if (isContinuousFeature) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) } else { - updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift) + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift) + } else { + updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex) } + } featureIndex += 1 } } @@ -1093,7 +1111,14 @@ object DecisionTree extends Serializable with Logging { if (isFeatureContinuous) { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) } else { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isSpaceSufficientForAllCategoricalSplits) { + findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } else { + findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) + } } } else { findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) @@ -1168,7 +1193,9 @@ object DecisionTree extends Serializable with Logging { numBins - 1 } else { // Categorical feature val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) - if (isMulticlassClassification) { + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { math.pow(2.0, featureCategories - 1).toInt - 1 } else { // Binary classification featureCategories @@ -1289,11 +1316,6 @@ object DecisionTree extends Serializable with Logging { val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + "in categorical features") - if (isMulticlassClassification) { - require(numBins > math.pow(2, maxCategoriesForFeatures.toInt - 1) - 1, - "numBins should be greater than 2^(maxNumCategories-1) -1 for multiclass classification" + - " with categorical variables") - } } @@ -1332,10 +1354,12 @@ object DecisionTree extends Serializable with Logging { } } else { // Categorical feature val featureCategories = strategy.categoricalFeaturesInfo(featureIndex) + val isSpaceSufficientForAllCategoricalSplits + = numBins > math.pow(2, featureCategories.toInt - 1) - 1 // Use different bin/split calculation strategy for categorical features in multiclass - // classification - if (isMulticlassClassification) { + // classification that satisfy the space constraint + if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) { // 2^(maxFeatureValue- 1) - 1 combinations var index = 0 while (index < math.pow(2.0, featureCategories - 1).toInt - 1) { @@ -1360,14 +1384,29 @@ object DecisionTree extends Serializable with Logging { } index += 1 } - } else { // regression or binary classification - - // For 
categorical variables, each bin is a category. The bins are sorted and they
-        // are ordered by calculating the centroid of their corresponding labels.
-        val centroidForCategories =
-          sampledInput.map(lp => (lp.features(featureIndex),lp.label))
-            .groupBy(_._1)
-            .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+      } else {
+
+        val centroidForCategories = {
+          if (isMulticlassClassification) {
+            // For categorical variables in multiclass classification,
+            // each bin is a category. The bins are sorted and they
+            // are ordered by calculating the impurity of their corresponding labels.
+            sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+              .groupBy(_._1)
+              .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
+              .map(x => (x._1, x._2.values.toArray))
+              .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum)))
+          } else { // regression or binary classification
+            // For categorical variables in regression and binary classification,
+            // each bin is a category. The bins are sorted and they
+            // are ordered by calculating the centroid of their corresponding labels.
+            sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+              .groupBy(_._1)
+              .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+          }
+        }
+
+        logDebug("centroid for categories = " + centroidForCategories.mkString(","))

         // Check for missing categorical variables and putting them last in the sorted list.
         val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index c06ad055afee..6a6ad5b87132 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -239,7 +239,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq)
   }

-  test("split and bin calculations for categorical variables with multiclass classification") {
+  test("split and bin calculations for unordered categorical variables with multiclass " +
+    "classification") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
     val rdd = sc.parallelize(arr)
@@ -332,6 +333,62 @@

   }

+  test("split and bin calculations for ordered categorical variables with multiclass " +
+    "classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    assert(arr.length === 3000)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(
+      Classification,
+      Gini,
+      maxDepth = 3,
+      numClassesForClassification = 100,
+      maxBins = 100,
+      categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
+
+    // 2^(10-1) - 1 = 511 > 100, so the categorical features will be ordered
+
+    assert(splits(0)(0).feature === 0)
+    assert(splits(0)(0).threshold === Double.MinValue)
+    assert(splits(0)(0).featureType === Categorical)
+    assert(splits(0)(0).categories.length === 1)
+    assert(splits(0)(0).categories.contains(1.0))
+
+    assert(splits(0)(1).feature === 0)
+    assert(splits(0)(1).threshold === Double.MinValue)
+    assert(splits(0)(1).featureType === Categorical)
+    assert(splits(0)(1).categories.length === 2)
+    assert(splits(0)(1).categories.contains(2.0))
+
+    assert(splits(0)(2).feature === 0)
+    assert(splits(0)(2).threshold === Double.MinValue)
+ 
+    assert(splits(0)(2).featureType === Categorical)
+    assert(splits(0)(2).categories.length === 3)
+    assert(splits(0)(2).categories.contains(2.0))
+    assert(splits(0)(2).categories.contains(1.0))
+
+    assert(splits(0)(10) === null)
+    assert(splits(1)(10) === null)
+
+
+    // Check bins.
+
+    assert(bins(0)(0).category === 1.0)
+    assert(bins(0)(0).lowSplit.categories.length === 0)
+    assert(bins(0)(0).highSplit.categories.length === 1)
+    assert(bins(0)(0).highSplit.categories.contains(1.0))
+    assert(bins(0)(1).category === 2.0)
+    assert(bins(0)(1).lowSplit.categories.length === 1)
+    assert(bins(0)(1).highSplit.categories.length === 2)
+    assert(bins(0)(1).highSplit.categories.contains(1.0))
+    assert(bins(0)(1).highSplit.categories.contains(2.0))
+
+    assert(bins(0)(10) === null)
+
+  }
+
+
   test("classification stump with all categorical variables") {
     val arr = DecisionTreeSuite.generateCategoricalDataPoints()
     assert(arr.length === 1000)
@@ -547,7 +604,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.featureType === Categorical)
   }

-
   test("stump with continuous variables for multiclass classification") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val input = sc.parallelize(arr)
@@ -568,7 +624,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
   }

-
   test("stump with continuous + categorical variables for multiclass classification") {
     val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass()
     val input = sc.parallelize(arr)
@@ -588,6 +643,26 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplit.threshold < 2020)
   }

+  test("stump with categorical variables for ordered multiclass classification") {
+    val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
+    assert(strategy.isMulticlassClassification)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
+    val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+      Array[List[Filter]](), splits, bins, 10)
+
+    assert(bestSplits.length === 1)
+    val bestSplit = bestSplits(0)._1
+    assert(bestSplit.feature === 0)
+    assert(bestSplit.categories.length === 1)
+    println(bestSplit)
+    assert(bestSplit.categories.contains(1.0))
+    assert(bestSplit.featureType === Categorical)
+  }
+
+
 }

 object DecisionTreeSuite {
@@ -662,5 +737,20 @@ object DecisionTreeSuite {
     arr
   }

+  def generateCategoricalDataPointsForMulticlassForOrderedFeatures():
+    Array[WeightedLabeledPoint] = {
+    val arr = new Array[WeightedLabeledPoint](3000)
+    for (i <- 0 until 3000) {
+      if (i < 1000) {
+        arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0))
+      } else if (i < 2000) {
+        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0))
+      } else {
+        arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 2.0))
+      }
+    }
+    arr
+  }
+
 }
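The two ordered-feature tests above hinge on the space check that decides whether a categorical feature gets unordered (subset-enumeration) splits or falls back to ordered, single-category splits. A minimal standalone sketch of that check, not part of any patch in this series (the object name and values are hypothetical):

object CategoricalSplitCheckSketch {
  // Unordered splits need room for every candidate subset, so they are only
  // used when all 2^(M-1) - 1 subsets of the M category values fit in numBins;
  // otherwise the feature is treated as ordered, one category per bin.
  def useUnorderedSplits(numBins: Int, featureCategories: Int): Boolean =
    numBins > math.pow(2, featureCategories - 1) - 1

  def main(args: Array[String]): Unit = {
    // maxBins = 100 and M = 10 categories: 2^(10-1) - 1 = 511 > 100, so ordered.
    println(useUnorderedSplits(numBins = 100, featureCategories = 10)) // false
    // maxBins = 100 and M = 3 categories: 2^(3-1) - 1 = 3 < 100, so unordered.
    println(useUnorderedSplits(numBins = 100, featureCategories = 3))  // true
  }
}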
From 8e44ab81b3c7a019eb48586d75495b43987683fb Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Wed, 4 Jun 2014 12:20:15 -0700
Subject: [PATCH 53/72] updated doc

---
 docs/mllib-decision-tree.md | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 8f8f54a3d92d..0f752f22243e 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -77,7 +77,7 @@ bins if the condition is not satisfied.

 **Categorical features**

-For `$M$` categorical features, one could come up with `$2^{M-1}-1$` split candidates. For
+For `$M$` categorical feature values, one could come up with `$2^{M-1}-1$` split candidates. For
 binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
 categorical feature values by the proportion of labels falling in one of the two classes (see
 Section 9.2.4 in
@@ -85,7 +85,9 @@ Section 9.2.4 in
 details). For example, for a binary classification problem with one categorical feature with three
 categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
 features are ordered as A followed by C followed by B, or A, C, B. The two split candidates are A \| C, B
-and A, C \| B where \| denotes the split.
+and A, C \| B where \| denotes the split. A similar ordering using impurity is performed
+for categorical feature values in multiclass classification when `$2^{M-1}-1$` is
+greater than the number of bins.

 ### Stopping rule

From 3d7f911efa628f5580b0cba0830080d538a4cc71 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Thu, 5 Jun 2014 01:38:05 -0700
Subject: [PATCH 54/72] updated doc

---
 docs/mllib-decision-tree.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 0f752f22243e..6109991121fe 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -85,9 +85,9 @@ Section 9.2.4 in
 details). For example, for a binary classification problem with one categorical feature with three
 categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
 features are ordered as A followed by C followed by B, or A, C, B. The two split candidates are A \| C, B
-and A, C \| B where \| denotes the split. A similar ordering using impurity is performed
-for categorical feature values in multiclass classification when `$2^{M-1}-1$` is
-greater than the number of bins.
+and A, C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
+when `$2^{M-1}-1$` is greater than the number of bins -- the impurity for each categorical feature value
+is used for ordering.

 ### Stopping rule
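To make the ordering heuristic in the two doc patches above concrete, here is a small self-contained sketch, with hypothetical data rather than code from any patch, that reproduces the A, C, B ordering of the example by sorting categories on their label centroid, the key used for binary classification and regression; for multiclass the sort key would instead be the per-category label impurity:

object CategoryOrderingSketch {
  def main(args: Array[String]): Unit = {
    // (category, label) samples; the proportion of label 1 is A = 0.2, B = 0.6, C = 0.4.
    val samples = Seq(
      "A" -> 0.0, "A" -> 0.0, "A" -> 0.0, "A" -> 0.0, "A" -> 1.0,
      "B" -> 1.0, "B" -> 1.0, "B" -> 1.0, "B" -> 0.0, "B" -> 0.0,
      "C" -> 1.0, "C" -> 1.0, "C" -> 0.0, "C" -> 0.0, "C" -> 0.0)
    val ordered = samples.groupBy(_._1).toSeq
      .map { case (cat, labels) => (cat, labels.map(_._2).sum / labels.size) } // centroid
      .sortBy(_._2)
      .map(_._1)
    // Prints "A, C, B"; the contiguous cuts in this ordering yield the two
    // split candidates A | C, B and A, C | B from the doc example.
    println(ordered.mkString(", "))
  }
}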
From 485eaaef278f7688851df5b4948b74b5df835ce0 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Sun, 6 Jul 2014 23:03:22 -0700
Subject: [PATCH 55/72] implicit conversion from LabeledPoint to
 WeightedLabeledPoint

---
 .../spark/mllib/point/PointConverter.scala    | 30 +++++++++++++++++++
 .../spark/mllib/tree/DecisionTree.scala       | 21 ++++---------
 2 files changed, 36 insertions(+), 15 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala
new file mode 100644
index 000000000000..022d2304423d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.point
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.regression.LabeledPoint
+
+object PointConverter {
+
+  implicit def LabeledPoint2WeightedLabeledPoint(
+      points : RDD[LabeledPoint]): RDD[WeightedLabeledPoint] = {
+    points.map(point => new WeightedLabeledPoint(point.label,point.features))
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index f5054dbf0d76..d1c522f203bf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.mllib.tree

+import org.apache.spark.mllib.point.PointConverter._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -211,9 +212,7 @@ object DecisionTree extends Serializable with Logging {
    * @return a DecisionTreeModel that can be used for prediction
    */
   def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
-    // Converting from standard instance format to weighted input format for tree training
-    val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
-    new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
+    new DecisionTree(strategy).train(input)
   }

 /**
@@ -235,9 +234,7 @@ object DecisionTree extends Serializable with Logging {
       impurity: Impurity,
       maxDepth: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth)
-    // Converting from standard instance format to weighted input format for tree training
-    val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
-    new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
+    new DecisionTree(strategy).train(input)
   }

 /**
@@ -261,9 +258,7 @@ object DecisionTree extends Serializable with Logging {
       maxDepth: Int,
       numClassesForClassification: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
-    // Converting from standard instance format to weighted input format for tree training
-    val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
-    new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
+    new DecisionTree(strategy).train(input)
   }


@@ -294,9 +289,7 @@ object DecisionTree extends Serializable with Logging {
       labelWeights: Map[Int,Int]): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification,
       labelWeights = labelWeights)
-    // Converting from standard instance format to weighted input format for tree training
-    val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features))
-    new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint])
+    new DecisionTree(strategy).train(input)
   }

 /**
@@ -337,9 +330,7 @@ object DecisionTree extends Serializable with Logging {
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
     val
strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights) - // Converting from standard instance format to weighted input format for tree training - val weightedInput = input.map(x => WeightedLabeledPoint(x.label, x.features)) - new DecisionTree(strategy).train(weightedInput: RDD[WeightedLabeledPoint]) + new DecisionTree(strategy).train(input) } private val InvalidBinIndex = -1 From 5c1b2cabf15247ce05d07e5636ff5c734fd1a4dd Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 8 Jul 2014 16:22:35 -0700 Subject: [PATCH 56/72] doc for PointConverter class --- .../scala/org/apache/spark/mllib/point/PointConverter.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala index 022d2304423d..2da986aee862 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala @@ -20,6 +20,9 @@ package org.apache.spark.mllib.point import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint +/** + * Class to convert between different point formats. + */ object PointConverter { implicit def LabeledPoint2WeightedLabeledPoint( From 9cc3e315434df8e25eb17f14b4fa7ff4f05d2a23 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 8 Jul 2014 16:36:07 -0700 Subject: [PATCH 57/72] added implicit conversion import --- .../scala/org/apache/spark/mllib/point/PointConverter.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala index 2da986aee862..1f31c4dadc21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.point +import scala.language.implicitConversions + import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint From 06b16906b1b6f4ecd74fe24c6708bc941166d448 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Tue, 8 Jul 2014 17:01:51 -0700 Subject: [PATCH 58/72] fixed off-by-one error in bin to split conversion --- .../scala/org/apache/spark/mllib/tree/DecisionTree.scala | 8 ++++---- .../org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 9 +++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index d1c522f203bf..a9a108420564 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1012,7 +1012,7 @@ object DecisionTree extends Serializable with Logging { = binData(shift + numClasses * splitIndex + innerClassIndex) + leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = - binData(shift + (numClasses * (numBins - 2 - splitIndex) + innerClassIndex)) + + binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) innerClassIndex += 1 } @@ -1077,13 +1077,13 @@ object DecisionTree extends Serializable with 
Logging { // calculating right node aggregate for a split as a sum of right node aggregate of a // higher split and the right bin aggregate of a bin where the split is a low split rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) = - binData(shift + (3 * (numBins - 2 - splitIndex))) + + binData(shift + (3 * (numBins - 1 - splitIndex))) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0) rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) = - binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1) rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) = - binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) + rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2) splitIndex += 1 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 6a6ad5b87132..6b6cab97935b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -412,9 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict === 0) - assert(stats.prob > 0.5) - assert(stats.prob < 0.6) + assert(stats.predict === 1) + assert(stats.prob == 0.6) assert(stats.impurity > 0.2) } @@ -440,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.4) - assert(stats.predict < 0.5) + assert(stats.predict == 0.6) assert(stats.impurity > 0.2) } @@ -657,7 +655,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val bestSplit = bestSplits(0)._1 assert(bestSplit.feature === 0) assert(bestSplit.categories.length === 1) - println(bestSplit) assert(bestSplit.categories.contains(1.0)) assert(bestSplit.featureType === Categorical) } From 0fecd381cbb0c90b7795f8f012a6fcb38f815f38 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 10 Jul 2014 15:52:38 -0700 Subject: [PATCH 59/72] minor: add newline to EOF --- .../org/apache/spark/mllib/point/WeightedLabeledPoint.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala index f7effcf182db..cc2d3caa0b86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala @@ -29,4 +29,4 @@ case class WeightedLabeledPoint(label: Double, features: Vector, weight:Double = override def toString: String = { "LabeledPoint(%s, %s, %s)".format(label, features, weight) } -} \ No newline at end of file +} From d75ac3211f4b951e6771451894f7b24718f7c08c Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 10 Jul 2014 17:48:37 -0700 Subject: [PATCH 60/72] removed WeightedLabeledPoint from this PR --- .../spark/mllib/point/PointConverter.scala | 35 ----------- .../mllib/point/WeightedLabeledPoint.scala | 32 ---------- .../spark/mllib/tree/DecisionTree.scala | 16 +++-- .../spark/mllib/tree/DecisionTreeSuite.scala | 58 +++++++++---------- 4 files changed, 36 insertions(+), 105 deletions(-) delete mode 100644 
mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala deleted file mode 100644 index 1f31c4dadc21..000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/point/PointConverter.scala +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.point - -import scala.language.implicitConversions - -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.regression.LabeledPoint - -/** - * Class to convert between different point formats. - */ -object PointConverter { - - implicit def LabeledPoint2WeightedLabeledPoint( - points : RDD[LabeledPoint]): RDD[WeightedLabeledPoint] = { - points.map(point => new WeightedLabeledPoint(point.label,point.features)) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala deleted file mode 100644 index cc2d3caa0b86..000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/point/WeightedLabeledPoint.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.point - -import org.apache.spark.mllib.linalg.Vector - -/** - * Class that represents the features and labels of a data point. - * - * @param label Label for this data point. - * @param features List of features for this data point. 
- */ -case class WeightedLabeledPoint(label: Double, features: Vector, weight:Double = 1) { - override def toString: String = { - "LabeledPoint(%s, %s, %s)".format(label, features, weight) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 9bda064bee55..a4524319c7ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.tree -import org.apache.spark.mllib.point.PointConverter._ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint @@ -29,7 +28,6 @@ import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.mllib.point.WeightedLabeledPoint /** * :: Experimental :: @@ -47,7 +45,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data * @return a DecisionTreeModel that can be used for prediction */ - def train(input: RDD[WeightedLabeledPoint]): DecisionTreeModel = { + def train(input: RDD[LabeledPoint]): DecisionTreeModel = { // Cache input RDD for speedup during multiple passes. input.cache() @@ -352,7 +350,7 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ protected[tree] def findBestSplits( - input: RDD[WeightedLabeledPoint], + input: RDD[LabeledPoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, @@ -400,7 +398,7 @@ object DecisionTree extends Serializable with Logging { * @return array of splits with best splits for all nodes at a given level. */ private def findBestSplitsPerGroup( - input: RDD[WeightedLabeledPoint], + input: RDD[LabeledPoint], parentImpurities: Array[Double], strategy: Strategy, level: Int, @@ -469,7 +467,7 @@ object DecisionTree extends Serializable with Logging { * Find whether the sample is valid input for the current node, i.e., whether it passes through * all the filters for the current node. */ - def isSampleValid(parentFilters: List[Filter], labeledPoint: WeightedLabeledPoint): Boolean = { + def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { // leaf if ((level > 0) && (parentFilters.length == 0)) { return false @@ -506,7 +504,7 @@ object DecisionTree extends Serializable with Logging { /** * Find bin for one feature. */ - def findBin(featureIndex: Int, labeledPoint: WeightedLabeledPoint, + def findBin(featureIndex: Int, labeledPoint: LabeledPoint, isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = { val binForFeatures = bins(featureIndex) val feature = labeledPoint.features(featureIndex) @@ -595,7 +593,7 @@ object DecisionTree extends Serializable with Logging { * classification and the categorical feature value in multiclass classification. * Invalid sample is denoted by noting bin for feature 1 as -1. */ - def findBinsForLevel(labeledPoint: WeightedLabeledPoint): Array[Double] = { + def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { // Calculate bin index and label per feature per node. 
val arr = new Array[Double](1 + (numFeatures * numNodes)) // First element of the array is the label of the instance. @@ -1283,7 +1281,7 @@ object DecisionTree extends Serializable with Logging { * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) */ protected[tree] def findSplitsBins( - input: RDD[WeightedLabeledPoint], + input: RDD[LabeledPoint], strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { val count = input.count() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 6b6cab97935b..5961a618c59d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree import org.scalatest.FunSuite -import org.apache.spark.mllib.point.WeightedLabeledPoint import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.Filter import org.apache.spark.mllib.tree.model.Split @@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.regression.LabeledPoint class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -664,86 +664,86 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { object DecisionTreeSuite { - def generateOrderedLabeledPointsWithLabel0(): Array[WeightedLabeledPoint] = { - val arr = new Array[WeightedLabeledPoint](1000) + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } arr } - def generateOrderedLabeledPointsWithLabel1(): Array[WeightedLabeledPoint] = { - val arr = new Array[WeightedLabeledPoint](1000) + def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { - val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) arr(i) = lp } arr } - def generateOrderedLabeledPoints(): Array[WeightedLabeledPoint] = { - val arr = new Array[WeightedLabeledPoint](1000) + def generateOrderedLabeledPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { if (i < 600) { - val lp = new WeightedLabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } else { - val lp = new WeightedLabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i)) arr(i) = lp } } arr } - def generateCategoricalDataPoints(): Array[WeightedLabeledPoint] = { - val arr = new Array[WeightedLabeledPoint](1000) + def generateCategoricalDataPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { if (i < 600) { - arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(0.0, 1.0)) + arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) } else { - arr(i) = new WeightedLabeledPoint(0.0, Vectors.dense(1.0, 0.0)) + 
arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) } } arr } - def generateCategoricalDataPointsForMulticlass(): Array[WeightedLabeledPoint] = { - val arr = new Array[WeightedLabeledPoint](3000) + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { if (i < 1000) { - arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { - arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) } else { - arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } } arr } - def generateContinuousDataPointsForMulticlass(): Array[WeightedLabeledPoint] = { - val arr = new Array[WeightedLabeledPoint](3000) + def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { if (i < 2000) { - arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, i)) + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i)) } else { - arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, i)) + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i)) } } arr } def generateCategoricalDataPointsForMulticlassForOrderedFeatures(): - Array[WeightedLabeledPoint] = { - val arr = new Array[WeightedLabeledPoint](3000) + Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) for (i <- 0 until 3000) { if (i < 1000) { - arr(i) = new WeightedLabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) } else if (i < 2000) { - arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) } else { - arr(i) = new WeightedLabeledPoint(1.0, Vectors.dense(2.0, 2.0)) + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0)) } } arr From e4c1321dee8d4b5632d65d8761d1ab5c896b0c14 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 14 Jul 2014 13:54:05 -0700 Subject: [PATCH 61/72] using while loop for regression histograms --- .../spark/mllib/tree/DecisionTree.scala | 34 +++++++------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a4524319c7ff..3c02ee99409a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -1063,27 +1063,19 @@ object DecisionTree extends Serializable with Logging { // Iterate over all splits. 
      var splitIndex = 1
       while (splitIndex < numBins - 1) {
-        // calculating left node aggregate for a split as a sum of left node aggregate of a
-        // lower split and the left bin aggregate of a bin where the split is a high split
-        leftNodeAgg(featureIndex)(splitIndex)(0) = binData(shift + 3 * splitIndex) +
-          leftNodeAgg(featureIndex)(splitIndex - 1)(0)
-        leftNodeAgg(featureIndex)(splitIndex)(1) = binData(shift + 3 * splitIndex + 1) +
-          leftNodeAgg(featureIndex)(splitIndex - 1)(1)
-        leftNodeAgg(featureIndex)(splitIndex)(2) = binData(shift + 3 * splitIndex + 2) +
-          leftNodeAgg(featureIndex)(splitIndex - 1)(2)
-
-        // calculating right node aggregate for a split as a sum of right node aggregate of a
-        // higher split and the right bin aggregate of a bin where the split is a low split
-        rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
-          binData(shift + (3 * (numBins - 1 - splitIndex))) +
-            rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
-        rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
-          binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
-            rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)
-        rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) =
-          binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
-            rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2)
-
+        var i = 0 // index for regression histograms
+        while (i < 3) { // count, sum, sum^2
+          // calculating left node aggregate for a split as a sum of left node aggregate of a
+          // lower split and the left bin aggregate of a bin where the split is a high split
+          leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) +
+            leftNodeAgg(featureIndex)(splitIndex - 1)(i)
+          // calculating right node aggregate for a split as a sum of right node aggregate of a
+          // higher split and the right bin aggregate of a bin where the split is a low split
+          rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) =
+            binData(shift + (3 * (numBins - 1 - splitIndex) + i)) +
+              rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i)
+          i += 1
+        }
         splitIndex += 1
       }
     }
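The rewritten loop above computes, for every candidate split, the left-child statistics as a running sum over bin triples (count, sum, sum of squares) and the right-child statistics as the matching suffix sum. A standalone sketch of the same prefix/suffix idea on made-up bin data, rather than the patch's flat binData layout (the object name and values are hypothetical):

object RegressionBinAggSketch {
  def main(args: Array[String]): Unit = {
    // One triple per bin: (count, sum of labels, sum of squared labels).
    val bins = Array(Array(4.0, 8.0, 20.0), Array(2.0, 1.0, 1.0), Array(3.0, 9.0, 27.0))
    val numBins = bins.length
    def add(a: Array[Double], b: Array[Double]): Array[Double] =
      Array(a(0) + b(0), a(1) + b(1), a(2) + b(2))
    // Prefix sums: left(s) aggregates bins 0..s.
    val left = bins.scanLeft(Array(0.0, 0.0, 0.0))(add).drop(1)
    // Suffix sums: right(i) aggregates bins i..numBins-1.
    val right = bins.scanRight(Array(0.0, 0.0, 0.0))((b, acc) => add(acc, b)).dropRight(1)
    // For split index s the left child sees bins 0..s and the right child sees
    // bins s+1 onward, enough to derive count, mean, and variance on each side.
    for (s <- 0 until numBins - 1) {
      println(s"split $s: left=${left(s).mkString(",")} right=${right(s + 1).mkString(",")}")
    }
  }
}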
From b2ae41ff89cfa372f57764111eb424bbd8059acb Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Mon, 14 Jul 2014 14:06:42 -0700
Subject: [PATCH 62/72] minor: scalastyle

---
 .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3c02ee99409a..32978d139f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1467,7 +1467,7 @@ object DecisionTree extends Serializable with Logging {
           // updating the list of categories.
           categories = j.toDouble :: categories
         }
-        //Right shift by one
+        // Right shift by one
         bitShiftedInput = bitShiftedInput >> 1
         j += 1
       }

From 4e85f2ceca8aa21f0aba14ef9d2e73dd968b64a9 Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Mon, 14 Jul 2014 14:18:02 -0700
Subject: [PATCH 63/72] minor: fixed scalastyle issues

---
 .../spark/mllib/tree/DecisionTree.scala       | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 32978d139f5d..f505ee50887c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -448,7 +448,8 @@ object DecisionTree extends Serializable with Logging {
     logDebug("isMulticlassClassification = " + isMulticlassClassification)
     val isMulticlassClassificationWithCategoricalFeatures
       = strategy.isMulticlassWithCategoricalFeatures
-    logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassClassificationWithCategoricalFeatures)
+    logDebug("isMultiClassWithCategoricalFeatures = " +
+      isMulticlassClassificationWithCategoricalFeatures)

     // shift when more than one group is used at deep tree level
     val groupShift = numNodes * groupIndex
@@ -643,7 +644,8 @@ object DecisionTree extends Serializable with Logging {
       val arrIndex = arrShift + featureIndex
       // Update the left or right count for one bin.
       val aggShift = numClasses * numBins * numFeatures * nodeIndex
-      val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
+      val aggIndex = aggShift + numClasses * featureIndex * numBins +
+        arr(arrIndex).toInt * numClasses
       val labelInt = label.toInt
       agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1)
     }
@@ -739,7 +741,8 @@ object DecisionTree extends Serializable with Logging {
           val isSpaceSufficientForAllCategoricalSplits
             = numBins > math.pow(2, featureCategories.toInt - 1) - 1
           if (isSpaceSufficientForAllCategoricalSplits) {
-            updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg, rightChildShift)
+            updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg,
+              rightChildShift)
           } else {
             updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
           }
@@ -909,10 +912,10 @@ object DecisionTree extends Serializable with Logging {
       .map{case (leftCount, rightCount) => leftCount + rightCount}

     def indexOfLargestArrayElement(array: Array[Double]): Int = {
-      val result = array.foldLeft(-1,Double.MinValue,0) {
+      val result = array.foldLeft(-1, Double.MinValue, 0) {
         case ((maxIndex, maxValue, currentIndex), currentValue) =>
-          if(currentValue > maxValue) (currentIndex,currentValue,currentIndex+1)
-          else (maxIndex,maxValue,currentIndex+1)
+          if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
+          else (maxIndex, maxValue, currentIndex+1)
       }
       if (result._1 < 0) 0 else result._1
     }
@@ -1408,8 +1411,8 @@ object DecisionTree extends Serializable with Logging {
       categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
         case ((key, value), index) =>
           categoriesForSplit = key :: categoriesForSplit
-          splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
-            categoriesForSplit)
+          splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
+            Categorical, categoriesForSplit)
           bins(featureIndex)(index) = {
             if (index == 0) {
               new Bin(new DummyCategoricalSplit(featureIndex, Categorical),

From
2d85a48477b469004a108019a1167ad6bb3a3117 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 14 Jul 2014 14:24:34 -0700 Subject: [PATCH 64/72] minor: fixed scalastyle issues reprise --- .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index f505ee50887c..b0f0bec899c8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -915,7 +915,7 @@ object DecisionTree extends Serializable with Logging { val result = array.foldLeft(-1, Double.MinValue, 0) { case ((maxIndex, maxValue, currentIndex), currentValue) => if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1) - else (maxIndex, maxValue, currentIndex+1) + else (maxIndex, maxValue, currentIndex + 1) } if (result._1 < 0) 0 else result._1 } From afced168ac9bce8893ae08467c43b35ebb67fe28 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Mon, 14 Jul 2014 16:11:26 -0700 Subject: [PATCH 65/72] removed label weights support --- .../spark/mllib/tree/DecisionTree.scala | 46 ++----------------- .../mllib/tree/configuration/Strategy.scala | 7 +-- 2 files changed, 5 insertions(+), 48 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b0f0bec899c8..fb53b588cdce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -259,37 +259,6 @@ object DecisionTree extends Serializable with Logging { new DecisionTree(strategy).train(input) } - - /** - * Method to train a decision tree model where the instances are represented as an RDD of - * (label, features) pairs. The method supports binary classification and regression. For the - * binary classification, the label for each instance should either be 0 or 1 to denote the two - * classes. - * - * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as - * training data - * @param algo algorithm, classification or regression - * @param impurity impurity criterion used for information gain calculation - * @param maxDepth maxDepth maximum depth of the tree - * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param labelWeights A map storing weights for each label to handle unbalanced - * datasets. For example, an entry (n -> k) implies the a weight of k is - * applied to an instance with label n. It's important to note that labels - * are zero-index and take values 0, 1, 2, ... , numClasses - 1. - * @return a DecisionTreeModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - algo: Algo, - impurity: Impurity, - maxDepth: Int, - numClassesForClassification: Int, - labelWeights: Map[Int,Int]): DecisionTreeModel = { - val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, - labelWeights = labelWeights) - new DecisionTree(strategy).train(input) - } - /** * Method to train a decision tree model where the instances are represented as an RDD of * (label, features) pairs. 
The decision tree method supports binary classification and @@ -303,10 +272,6 @@ object DecisionTree extends Serializable with Logging { * @param impurity criterion used for information gain calculation * @param maxDepth maximum depth of the tree * @param numClassesForClassification number of classes for classification. Default value of 2. - * @param labelWeights A map storing weights applied to each label for handling unbalanced - * datasets. For example, an entry (n -> k) implies the a weight of k is - * applied to an instance with label n. It's important to note that labels - * are zero-index and take values 0, 1, 2, ... , numClasses - 1. * @param maxBins maximum number of bins used for splitting features * @param quantileCalculationStrategy algorithm for calculating quantiles * @param categoricalFeaturesInfo A map storing information about the categorical variables and @@ -322,12 +287,11 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int, numClassesForClassification: Int, - labelWeights: Map[Int,Int], maxBins: Int, quantileCalculationStrategy: QuantileStrategy, categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, - quantileCalculationStrategy, categoricalFeaturesInfo, labelWeights = labelWeights) + quantileCalculationStrategy, categoricalFeaturesInfo) new DecisionTree(strategy).train(input) } @@ -442,8 +406,6 @@ object DecisionTree extends Serializable with Logging { logDebug("numBins = " + numBins) val numClasses = strategy.numClassesForClassification logDebug("numClasses = " + numClasses) - val labelWeights = strategy.labelWeights - logDebug("labelWeights = " + labelWeights) val isMulticlassClassification = strategy.isMulticlassClassification logDebug("isMulticlassClassification = " + isMulticlassClassification) val isMulticlassClassificationWithCategoricalFeatures @@ -647,7 +609,7 @@ object DecisionTree extends Serializable with Logging { val aggIndex = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses val labelInt = label.toInt - agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + labelWeights.getOrElse(labelInt, 1) + agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1 } def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double], @@ -667,10 +629,10 @@ object DecisionTree extends Serializable with Logging { val labelInt = label.toInt if (bins(featureIndex)(binIndex).highSplit.categories.contains(labelInt)) { agg(aggIndex + binIndex) - = agg(aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1) + = agg(aggIndex + binIndex) + 1 } else { agg(rightChildShift + aggIndex + binIndex) - = agg(rightChildShift + aggIndex + binIndex) + labelWeights.getOrElse(labelInt, 1) + = agg(rightChildShift + aggIndex + binIndex) + 1 } binIndex += 1 } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 7aec14d293ec..7c027ac2fda6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -39,10 +39,6 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * zero-indexed. * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is * 128 MB. 
- * @param labelWeights A map storing weights applied to each label for handling unbalanced
- *                     datasets. For example, an entry (n -> k) implies the a weight of k is
- *                     applied to an instance with label n. It's important to note that labels
- *                     are zero-index and take values 0, 1, 2, ... , numClasses.
  *
  */
 @Experimental
 class Strategy (
@@ -54,8 +50,7 @@ class Strategy (
     val maxBins: Int = 100,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128,
-    val labelWeights: Map[Int, Int] = Map[Int, Int]()) extends Serializable {
+    val maxMemoryInMB: Int = 128) extends Serializable {

   require(numClassesForClassification >= 2)
   val isMulticlassClassification = numClassesForClassification > 2

From c8428c4e7480ba1758139e85095318d7bc38b0ea Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Mon, 14 Jul 2014 22:40:02 -0700
Subject: [PATCH 66/72] fixing weird multiline bug

---
 .../main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index fb53b588cdce..ad32e3f4560f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -606,8 +606,8 @@ object DecisionTree extends Serializable with Logging {
       val arrIndex = arrShift + featureIndex
       // Update the left or right count for one bin.
       val aggShift = numClasses * numBins * numFeatures * nodeIndex
-      val aggIndex = aggShift + numClasses * featureIndex * numBins
-        + arr(arrIndex).toInt * numClasses
+      val aggIndex
+        = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
       val labelInt = label.toInt
       agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
     }

From 45e767afaedcd8f99d390464d211690addbc1f9b Mon Sep 17 00:00:00 2001
From: Manish Amde
Date: Tue, 15 Jul 2014 12:53:24 -0700
Subject: [PATCH 67/72] adding developer api annotation for overridden methods

---
 .../apache/spark/mllib/tree/impurity/Entropy.scala   |  8 ++++++++
 .../org/apache/spark/mllib/tree/impurity/Gini.scala  |  8 ++++++++
 .../apache/spark/mllib/tree/impurity/Variance.scala  | 11 ++++++++++-
 3 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index ead76d64b638..a0e2d9176278 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -52,6 +52,14 @@ object Entropy extends Impurity {
     impurity
   }

+  /**
+   * :: DeveloperApi ::
+   * variance calculation
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   */
+  @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Entropy.calculate")
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index c8773fc4f860..48144b5e6d1e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -48,6 +48,14 @@ object Gini extends Impurity {
     impurity
   }

+  /**
+   * ::
DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 555754b1ee03..97149a99ead5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} */ @Experimental object Variance extends Impurity { - override def calculate(counts: Array[Double], totalCounts: Double): Double = + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value + */ + @DeveloperApi + override def calculate(counts: Array[Double], totalCount: Double): Double = throw new UnsupportedOperationException("Variance.calculate") /** From abf29014272aefe704700dca3c6782df323e07ce Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 17 Jul 2014 13:37:02 -0700 Subject: [PATCH 68/72] adding classes to MimaExcludes.scala --- project/MimaExcludes.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3b7b87b80cda..c07d35496a60 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -81,6 +81,9 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Values") ++ MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") + MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Gini") + MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Entropy") + MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Variance") case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"), From 10fdd826abb831d65052f63645690cd24d615e11 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 17 Jul 2014 15:32:12 -0700 Subject: [PATCH 69/72] fixing MIMA excludes --- project/MimaExcludes.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index cc453fb6bdff..24dafabfc4a5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -82,9 +82,9 @@ object MimaExcludes { MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ MimaBuild.excludeSparkClass("storage.Values") ++ MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") - MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Gini") - MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Entropy") + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Gini") ++ + MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Entropy") ++ MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Variance") case v if v.startsWith("1.0") => Seq( From 1ce7212bc0ce9df0d084ea5fa936146fd655e78a Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 17 Jul 2014 20:03:40 -0700 Subject: 
[PATCH 70/72] change problem filter for mima --- project/MimaExcludes.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 24dafabfc4a5..17bf9eda7b76 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -83,9 +83,11 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Values") ++ MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ - MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Gini") ++ - MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Entropy") ++ - MimaBuild.excludeSparkClass("org.apache.spark.mllib.tree.impurity.Variance") + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Gini.calculate") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Entropy.calculate") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"), From c5b2d04a3aad5e4ec01c35b13edfe43e83702be7 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 17 Jul 2014 20:11:07 -0700 Subject: [PATCH 71/72] more MIMA fixes --- project/MimaExcludes.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 17bf9eda7b76..757853b21a55 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -84,8 +84,8 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Gini.calculate") - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Entropy.calculate") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Entropy.calculate"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Variance.calculate") ) case v if v.startsWith("1.0") => From 26f8acc89433cb2a5f77cd9701a50a381cc5f3a2 Mon Sep 17 00:00:00 2001 From: Manish Amde Date: Thu, 17 Jul 2014 22:12:12 -0700 Subject: [PATCH 72/72] another attempt at fixing mima --- project/MimaExcludes.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 757853b21a55..e0f433b26f7f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -84,9 +84,12 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ Seq( - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Gini.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Entropy.calculate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.tree.impurity.Variance.calculate") + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + 
ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) case v if v.startsWith("1.0") => Seq(