
Commit e676da1

Updated documentation for DecisionTree
1 parent 37ca845 commit e676da1

File tree: 10 files changed, +170 -70 lines changed


examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala

Lines changed: 1 addition & 2 deletions
@@ -156,9 +156,8 @@ object DecisionTreeRunner {
         throw new IllegalArgumentException("Algo ${params.algo} not supported.")
     }

-    println("opt3")
     // Split into training, test.
-    val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest), seed = 12345)
+    val splits = examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest))
     val training = splits(0).cache()
     val test = splits(1).cache()
     val numTraining = training.count()
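Note on the second change: without an explicit seed, randomSplit draws a new random seed on each run, so the train/test split (and hence any reported metrics) is no longer reproducible. A minimal sketch of the trade-off, using only the standard RDD.randomSplit API (the helper name below is illustrative, not part of this commit):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Illustrative helper, not from this commit.
def splitExamples(
    examples: RDD[LabeledPoint],
    fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint]) = {
  // No seed argument: a random seed is chosen, so each run yields a different split.
  // Passing a fixed seed, e.g. randomSplit(weights, seed = 12345L), restores the old
  // reproducible behavior.
  val splits = examples.randomSplit(Array(1.0 - fracTest, fracTest))
  (splits(0).cache(), splits(1).cache())
}
```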

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 35 additions & 33 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, DTStatsAggregator, TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impl._
 import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model._
@@ -122,7 +122,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     var break = false
     while (level <= maxDepth && !break) {

-      //println(s"LEVEL $level")
       logDebug("#####################################")
       logDebug("level = " + level)
       logDebug("#####################################")
@@ -198,14 +197,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo

     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
-    println(s"$timer")

     new DecisionTreeModel(topNode, strategy.algo)
   }

 }

-
 object DecisionTree extends Serializable with Logging {

   /**
@@ -456,13 +453,21 @@ object DecisionTree extends Serializable with Logging {
    * This function mimics prediction, passing an example from the root node down to a node
    * at the current level being trained; that node's index is returned.
    *
+   * @param node  Node in tree from which to classify the given data point.
+   * @param binnedFeatures  Binned feature vector for data point.
+   * @param bins  possible bins for all features, indexed (numFeatures)(numBins)
+   * @param unorderedFeatures  Set of indices of unordered features.
    * @return  Leaf index if the data point reaches a leaf.
    *          Otherwise, last node reachable in tree matching this example.
    *          Note: This is the global node index, i.e., the index used in the tree.
    *                This index is different from the index used during training a particular
    *                set of nodes in a (level, group).
    */
-  def predictNodeIndex(node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Int = {
+  def predictNodeIndex(
+      node: Node,
+      binnedFeatures: Array[Int],
+      bins: Array[Array[Bin]],
+      unorderedFeatures: Set[Int]): Int = {
     if (node.isLeaf) {
       node.id
     } else {
@@ -499,15 +504,18 @@ object DecisionTree extends Serializable with Logging {
   }

   /**
-   * Helper for binSeqOp.
+   * Helper for binSeqOp, for data containing some unordered (categorical) features.
    *
-   * @param agg  Array storing aggregate calculation.
-   *             For ordered features, this is of size:
-   *               numClasses * numBins * numFeatures * numNodes.
-   *             For unordered features, this is of size:
-   *               2 * numClasses * numBins * numFeatures * numNodes.
-   * @param treePoint  Data point being aggregated.
+   * For ordered features, a single bin is updated.
+   * For unordered features, bins correspond to subsets of categories; either the left or right bin
+   * for each subset is updated.
+   *
+   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+   *             each (node, feature, bin).
+   * @param treePoint  Data point being aggregated.
    * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
+   * @param bins  possible bins for all features, indexed (numFeatures)(numBins)
+   * @param unorderedFeatures  Set of indices of unordered features.
    */
   def someUnorderedBinSeqOp(
       agg: DTStatsAggregator,
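To make the new someUnorderedBinSeqOp comment concrete: for an unordered categorical feature, each candidate split corresponds to a subset of categories, and the aggregator keeps a left and a right bin per subset. A self-contained sketch of the per-point update, assuming subsets are encoded as bit masks over categories (an illustration, not the exact MLlib representation):

```scala
// For one data point and one unordered feature: visit every category subset and
// update that subset's left bin if the point's category is in the subset,
// else its right bin. Bins are laid out as [left bins | right bins].
def updateUnorderedFeature(
    featureValue: Int,            // this point's category for the feature
    numSubsets: Int,              // number of candidate splits (category subsets)
    updateBin: Int => Unit): Unit = {
  var subset = 0
  while (subset < numSubsets) {
    // Treat (subset + 1) as a bit mask: bit c set means category c goes left.
    val goesLeft = (((subset + 1) >> featureValue) & 1) == 1
    updateBin(if (goesLeft) subset else numSubsets + subset)
    subset += 1
  }
}
```

Note how this matches the doc change: unlike an ordered feature (one bin touched per point), every subset's bin pair is touched once per point.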
@@ -547,15 +555,13 @@ object DecisionTree extends Serializable with Logging {
   }

   /**
-   * Helper for binSeqOp: for regression and for classification with only ordered features.
+   * Helper for binSeqOp, for regression and for classification with only ordered features.
    *
-   * Performs a sequential aggregation over a partition for regression.
-   * For l nodes, k features,
-   * the count, sum, sum of squares of one of the p bins is incremented.
+   * For each feature, the sufficient statistics of one bin are updated.
    *
-   * @param agg  Array storing aggregate calculation, updated by this function.
-   *             Size: 3 * numBins * numFeatures * numNodes
-   * @param treePoint  Data point being aggregated.
+   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+   *             each (node, feature, bin).
+   * @param treePoint  Data point being aggregated.
    * @param nodeIndex  Node corresponding to treePoint. Indexed from 0 at start of (level, group).
    * @return  agg
    */
@@ -582,6 +588,7 @@ object DecisionTree extends Serializable with Logging {
    * @param parentImpurities  Impurities for all parent nodes for the current level
    * @param metadata  Learning and dataset metadata
    * @param level  Level of the tree
+   * @param nodes  Array of all nodes in the tree. Used for matching data points to nodes.
    * @param splits  possible splits for all features, indexed (numFeatures)(numSplits)
    * @param bins  possible bins for all features, indexed (numFeatures)(numBins)
    * @param numGroups  total number of node groups at the current level. Default value is set to 1.
@@ -663,19 +670,12 @@ object DecisionTree extends Serializable with Logging {

   /**
    * Performs a sequential aggregation over a partition.
-   * For l nodes, k features,
-   * For classification:
-   *   Either the left count or the right count of one of the bins is
-   *   incremented based upon whether the feature is classified as 0 or 1.
-   * For regression:
-   *   The count, sum, sum of squares of one of the bins is incremented.
    *
-   * @param agg  Array storing aggregate calculation, updated by this function.
-   *             Size for classification:
-   *               Ordered features: numNodes * numFeatures * numBins.
-   *               Unordered features: (2 * numNodes) * numFeatures * numBins.
-   *             Size for regression:
-   *               numNodes * numFeatures * numBins.
+   * Each data point contributes to one node. For each feature,
+   * the aggregate sufficient statistics are updated for the relevant bins.
+   *
+   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
+   *             each (node, feature, bin).
    * @param treePoint  Data point being aggregated.
    * @return  agg
    */
@@ -883,8 +883,10 @@ object DecisionTree extends Serializable with Logging {
         val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val featureValue = categoriesSortedByCentroid(splitIndex)._1
-            val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
-            val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
+            val leftChildStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
+            val rightChildStats =
+              binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
             val gainStats =
               calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
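For readers following predictNodeIndex: the descent it performs can be sketched with simplified stand-in types. SimpleNode and SimpleSplit below are hypothetical, not the MLlib Node/Split classes, and this version tests raw continuous values rather than binned features:

```scala
// Hypothetical stand-ins for the MLlib tree model classes.
case class SimpleSplit(feature: Int, threshold: Double)
case class SimpleNode(
    id: Int,
    split: Option[SimpleSplit],
    left: Option[SimpleNode],
    right: Option[SimpleNode]) {
  def isLeaf: Boolean = split.isEmpty
}

// Walk from the root toward the frontier being trained; return the id of the
// deepest node this example reaches (a leaf, or a node whose children don't exist yet).
def predictNodeIndexSketch(node: SimpleNode, features: Array[Double]): Int = {
  if (node.isLeaf) {
    node.id
  } else {
    val s = node.split.get
    val child = if (features(s.feature) <= s.threshold) node.left else node.right
    child match {
      case Some(c) => predictNodeIndexSketch(c, features)
      case None => node.id  // children not trained yet: last reachable node
    }
  }
}
```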

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala

Lines changed: 20 additions & 8 deletions
@@ -19,36 +19,46 @@ package org.apache.spark.mllib.tree.impl

 import org.apache.spark.mllib.tree.impurity._

-import scala.collection.mutable
-
-
 /**
- * :: Experimental ::
  * DecisionTree statistics aggregator.
  * This holds a flat array of statistics for a set of (nodes, features, bins)
  * and helps with indexing.
- * TODO: Allow views of Vector types to replace some of the code in here.
  */
 private[tree] class DTStatsAggregator(
     metadata: DecisionTreeMetadata,
     val numNodes: Int) extends Serializable {

+  /**
+   * [[ImpurityAggregator]] instance specifying the impurity type.
+   */
   val impurityAggregator: ImpurityAggregator = metadata.impurity match {
     case Gini => new GiniAggregator(metadata.numClasses)
     case Entropy => new EntropyAggregator(metadata.numClasses)
     case Variance => new VarianceAggregator()
     case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
   }

+  /**
+   * Number of elements (Double values) used for the sufficient statistics of each bin.
+   */
   val statsSize: Int = impurityAggregator.statsSize

   val numFeatures: Int = metadata.numFeatures

+  /**
+   * Number of bins for each feature. This is indexed by the feature index.
+   */
   val numBins: Array[Int] = metadata.numBins

-  val isUnordered: Array[Boolean] =
-    Range(0, numFeatures).map(f => metadata.unorderedFeatures.contains(f)).toArray
+  /**
+   * Indicator for each feature of whether that feature is an unordered feature.
+   * TODO: Is Array[Boolean] any faster?
+   */
+  def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)

+  /**
+   * Offset for each feature for calculating indices into the [[allStats]] array.
+   */
   private val featureOffsets: Array[Int] = {
     def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
       if (isUnordered(featureIndex)) {
@@ -105,8 +115,9 @@ private[tree] class DTStatsAggregator(
   def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

   /**
+   * Faster version of [[update]].
    * Update the stats for a given (node, feature, bin) for ordered features, using the given label.
-   * This uses a pre-computed node offset from [[getNodeOffset]].
+   * @param nodeOffset  Pre-computed node offset from [[getNodeOffset]].
    */
   def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = {
     val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize
@@ -137,6 +148,7 @@ private[tree] class DTStatsAggregator(
   }

   /**
+   * Faster version of [[update]].
    * Update the stats for a given (node, feature, bin), using the given label.
    * @param nodeFeatureOffset  For ordered features, this is a pre-computed (node, feature) offset
    *                           from [[getNodeFeatureOffset]].
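The indexing scheme these comments describe can be sketched as follows: all statistics live in one flat Array[Double], and a (node, feature, bin) triple maps to nodeOffset + featureOffset + binIndex * statsSize. The "faster version" methods win by letting callers precompute the node offset once per data point. Names below are illustrative, not the actual internals:

```scala
// Sketch of a flat (node, feature, bin) statistics array, assuming all features ordered.
class FlatAggregatorSketch(numNodes: Int, numBinsPerFeature: Array[Int], statsSize: Int) {
  // featureOffsets(f) = start of feature f's bins within one node's block.
  private val featureOffsets: Array[Int] =
    numBinsPerFeature.scanLeft(0)((total, nBins) => total + nBins * statsSize)
  private val nodeStride: Int = featureOffsets.last    // size of one node's block
  val allStats = new Array[Double](numNodes * nodeStride)

  def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride

  // "Faster version": nodeOffset is computed once by the caller, not once per feature.
  def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int,
      statIndex: Int, value: Double): Unit = {
    allStats(nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + statIndex) += value
  }
}
```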

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 10 additions & 20 deletions
@@ -24,30 +24,17 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impurity.Impurity
-import org.apache.spark.mllib.tree.DecisionTree
 import org.apache.spark.rdd.RDD

-
-/*
- * TODO: Add doc about ordered vs. unordered features.
- * Ensure numBins is always greater than the categories. For multiclass classification,
- * numBins should be greater than math.pow(2, maxCategories - 1) - 1.
- * It's a limitation of the current implementation but a reasonable trade-off since features
- * with large number of categories get favored over continuous features.
- *
- * This needs to be checked here instead of in Strategy since numBins can be determined
- * by the number of training examples.
- */
-
-
 /**
  * Learning and dataset metadata for DecisionTree.
  *
  * @param numClasses    For classification: labels can take values {0, ..., numClasses - 1}.
  *                      For regression: fixed at 0 (no meaning).
+ * @param maxBins  Maximum number of bins, for all features.
  * @param featureArity  Map: categorical feature index --> arity.
  *                      I.e., the feature takes values in {0, ..., arity - 1}.
- * @param numBins  numBins(featureIndex) = number of bins for feature
+ * @param numBins  Number of bins for each feature.
  */
 private[tree] class DecisionTreeMetadata(
     val numFeatures: Int,
@@ -82,6 +69,11 @@ private[tree] class DecisionTreeMetadata(

 private[tree] object DecisionTreeMetadata {

+  /**
+   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
+   * This computes which categorical features will be ordered vs. unordered,
+   * as well as the number of splits and bins for each feature.
+   */
   def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = {

     val numFeatures = input.take(1)(0).features.size
@@ -94,6 +86,9 @@ private[tree] object DecisionTreeMetadata {
     val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
     val log2MaxPossibleBinsp1 = math.log(maxPossibleBins + 1) / math.log(2.0)

+    // We check the number of bins here against maxPossibleBins.
+    // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
+    // based on the number of training examples.
     val unorderedFeatures = new mutable.HashSet[Int]()
     val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
     if (numClasses > 2) {
@@ -104,11 +99,6 @@ private[tree] object DecisionTreeMetadata {
           unorderedFeatures.add(f)
           numBins(f) = numUnorderedBins(k)
         } else {
-          // TODO: Check the below k <= maxBins.
-          //       Checking k <= maxPossibleBins should work.
-          //       However, there may have been a 1-off error later on allocating 1 extra
-          //       (unused) bin.
-          // TODO: Allow this case, where we simply will know nothing about some categories?
           require(k <= maxPossibleBins,
             s"maxBins (= $maxPossibleBins) should be greater than max categories " +
             s"in categorical features (>= $k)")
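The require above ties back to the bin budget for categorical features. For multiclass classification, a hedged sketch of the arithmetic, assuming each non-trivial category subset gets one left and one right bin (illustrative; the exact formula lives in DecisionTreeMetadata.numUnorderedBins):

```scala
// For a categorical feature with arity k there are 2^(k-1) - 1 distinct non-trivial
// category subsets (up to complement), so about twice that many bins are needed.
def unorderedBinsNeeded(arity: Int): Int = 2 * ((1 << (arity - 1)) - 1)

// A feature can be treated as unordered only if those bins fit within maxPossibleBins.
// Example: arity 4 needs 2 * 7 = 14 bins, so maxBins must be at least 14.
def fitsAsUnordered(arity: Int, maxPossibleBins: Int): Boolean =
  unorderedBinsNeeded(arity) <= maxPossibleBins
```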

mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala

Lines changed: 27 additions & 0 deletions
@@ -75,6 +75,12 @@ object Entropy extends Impurity {

 }

+/**
+ * Class for updating views of a vector of sufficient statistics,
+ * in order to compute impurity from a sample.
+ * Note: Instances of this class do not hold the data; they operate on views of the data.
+ * @param numClasses  Number of classes for label.
+ */
 private[tree] class EntropyAggregator(numClasses: Int)
   extends ImpurityAggregator(numClasses) with Serializable {

@@ -102,20 +108,41 @@ private[tree] class EntropyAggregator(numClasses: Int)

 }

+/**
+ * Stores statistics for one (node, feature, bin) for calculating impurity.
+ * Unlike [[EntropyAggregator]], this class stores its own data and is for a specific
+ * (node, feature, bin).
+ * @param stats  Array of sufficient statistics for a (node, feature, bin).
+ */
 private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {

+  /**
+   * Make a deep copy of this [[ImpurityCalculator]].
+   */
   def copy: EntropyCalculator = new EntropyCalculator(stats.clone())

+  /**
+   * Calculate the impurity from the stored sufficient statistics.
+   */
   def calculate(): Double = Entropy.calculate(stats, stats.sum)

+  /**
+   * Number of data points accounted for in the sufficient statistics.
+   */
   def count: Long = stats.sum.toLong

+  /**
+   * Prediction which should be made based on the sufficient statistics.
+   */
   def predict: Double = if (count == 0) {
     0
   } else {
     indexOfLargestArrayElement(stats)
   }

+  /**
+   * Probability of the label given by [[predict]].
+   */
   override def prob(label: Double): Double = {
     val lbl = label.toInt
     require(lbl < stats.length,