@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Strategy
2828import org .apache .spark .mllib .tree .configuration .Algo ._
2929import org .apache .spark .mllib .tree .configuration .FeatureType ._
3030import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
31- import org .apache .spark .mllib .tree .impl .{ DecisionTreeMetadata , DTStatsAggregator , TimeTracker , TreePoint }
31+ import org .apache .spark .mllib .tree .impl ._
3232import org .apache .spark .mllib .tree .impurity .{Impurities , Impurity }
3333import org .apache .spark .mllib .tree .impurity ._
3434import org .apache .spark .mllib .tree .model ._
@@ -122,7 +122,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
122122 var break = false
123123 while (level <= maxDepth && ! break) {
124124
125- // println(s"LEVEL $level")
126125 logDebug(" #####################################" )
127126 logDebug(" level = " + level)
128127 logDebug(" #####################################" )
@@ -198,14 +197,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
198197
199198 logInfo(" Internal timing for DecisionTree:" )
200199 logInfo(s " $timer" )
201- println(s " $timer" )
202200
203201 new DecisionTreeModel (topNode, strategy.algo)
204202 }
205203
206204}
207205
208-
209206object DecisionTree extends Serializable with Logging {
210207
211208 /**
@@ -456,13 +453,21 @@ object DecisionTree extends Serializable with Logging {
456453 * This function mimics prediction, passing an example from the root node down to a node
457454 * at the current level being trained; that node's index is returned.
458455 *
456+ * @param node Node in tree from which to classify the given data point.
457+ * @param binnedFeatures Binned feature vector for data point.
458+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
459+ * @param unorderedFeatures Set of indices of unordered features.
459460 * @return Leaf index if the data point reaches a leaf.
460461 * Otherwise, last node reachable in tree matching this example.
461462 * Note: This is the global node index, i.e., the index used in the tree.
462463 * This index is different from the index used during training a particular
463464 * set of nodes in a (level, group).
464465 */
465- def predictNodeIndex (node : Node , binnedFeatures : Array [Int ], bins : Array [Array [Bin ]], unorderedFeatures : Set [Int ]): Int = {
466+ def predictNodeIndex (
467+ node : Node ,
468+ binnedFeatures : Array [Int ],
469+ bins : Array [Array [Bin ]],
470+ unorderedFeatures : Set [Int ]): Int = {
466471 if (node.isLeaf) {
467472 node.id
468473 } else {
@@ -499,15 +504,18 @@ object DecisionTree extends Serializable with Logging {
499504 }
500505
501506 /**
502- * Helper for binSeqOp.
507+ * Helper for binSeqOp, for data containing some unordered (categorical) features.
503508 *
504- * @param agg Array storing aggregate calculation.
505- * For ordered features, this is of size:
506- * numClasses * numBins * numFeatures * numNodes.
507- * For unordered features, this is of size:
508- * 2 * numClasses * numBins * numFeatures * numNodes.
509- * @param treePoint Data point being aggregated.
509+ * For ordered features, a single bin is updated.
510+ * For unordered features, bins correspond to subsets of categories; either the left or right bin
511+ * for each subset is updated.
512+ *
513+ * @param agg Aggregator (DTStatsAggregator) storing aggregate calculation, with a set of
514+ * sufficient statistics for each (node, feature, bin).
515+ * @param treePoint Data point being aggregated.
510516 * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
517+ * @param bins possible bins for all features, indexed (numFeatures)(numBins)
518+ * @param unorderedFeatures Set of indices of unordered features.
511519 */
512520 def someUnorderedBinSeqOp (
513521 agg : DTStatsAggregator ,
@@ -547,15 +555,13 @@ object DecisionTree extends Serializable with Logging {
547555 }
548556
549557 /**
550- * Helper for binSeqOp: for regression and for classification with only ordered features.
558+ * Helper for binSeqOp, for regression and for classification with only ordered features.
551559 *
552- * Performs a sequential aggregation over a partition for regression.
553- * For l nodes, k features,
554- * the count, sum, sum of squares of one of the p bins is incremented.
560+ * For each feature, the sufficient statistics of one bin are updated.
555561 *
556- * @param agg Array storing aggregate calculation, updated by this function.
557- * Size: 3 * numBins * numFeatures * numNodes
558- * @param treePoint Data point being aggregated.
562+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
563+ * each (node, feature, bin).
564+ * @param treePoint Data point being aggregated.
559565 * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group).
560566 * @return agg
561567 */
@@ -582,6 +588,7 @@ object DecisionTree extends Serializable with Logging {
582588 * @param parentImpurities Impurities for all parent nodes for the current level
583589 * @param metadata Learning and dataset metadata
584590 * @param level Level of the tree
591+ * @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
585592 * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
586593 * @param bins possible bins for all features, indexed (numFeatures)(numBins)
587594 * @param numGroups total number of node groups at the current level. Default value is set to 1.
@@ -663,19 +670,12 @@ object DecisionTree extends Serializable with Logging {
663670
664671 /**
665672 * Performs a sequential aggregation over a partition.
666- * For l nodes, k features,
667- * For classification:
668- * Either the left count or the right count of one of the bins is
669- * incremented based upon whether the feature is classified as 0 or 1.
670- * For regression:
671- * The count, sum, sum of squares of one of the bins is incremented.
672673 *
673- * @param agg Array storing aggregate calculation, updated by this function.
674- * Size for classification:
675- * Ordered features: numNodes * numFeatures * numBins.
676- * Unordered features: (2 * numNodes) * numFeatures * numBins.
677- * Size for regression:
678- * numNodes * numFeatures * numBins.
674+ * Each data point contributes to one node. For each feature,
675+ * the aggregate sufficient statistics are updated for the relevant bins.
676+ *
677+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics for
678+ * each (node, feature, bin).
679679 * @param treePoint Data point being aggregated.
680680 * @return agg
681681 */
@@ -883,8 +883,10 @@ object DecisionTree extends Serializable with Logging {
883883 val (bestFeatureSplitIndex, bestFeatureGainStats) =
884884 Range (0 , numSplits).map { splitIndex =>
885885 val featureValue = categoriesSortedByCentroid(splitIndex)._1
886- val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
887- val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
886+ val leftChildStats =
887+ binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
888+ val rightChildStats =
889+ binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
888890 rightChildStats.subtract(leftChildStats)
889891 val gainStats =
890892 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
0 commit comments