[SPARK-22451][ML] Reduce decision tree aggregate size for unordered features from O(2^numCategories) to O(numCategories) #19666
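For an unordered categorical feature with k categories there are 2^(k-1) - 1 candidate subset splits. Before this patch the aggregator kept one bin of sufficient statistics per candidate split; with this patch it keeps one bin per category and enumerates the subsets only while searching for the best split, which is what shrinks the aggregate from O(2^numCategories) to O(numCategories). A back-of-the-envelope sketch of the difference (illustrative Scala, not Spark code):

// Sketch: aggregate bins per unordered feature, before vs. after this patch.
object UnorderedAggregateSize {
  // Before: one bin per candidate subset split, 2^(arity-1) - 1 of them.
  def binsBefore(arity: Int): Long = (1L << (arity - 1)) - 1
  // After: one bin per category.
  def binsAfter(arity: Int): Long = arity.toLong

  def main(args: Array[String]): Unit = {
    Seq(3, 10, 16).foreach { arity =>
      println(s"arity=$arity: before=${binsBefore(arity)} bins, after=${binsAfter(arity)} bins")
    }
    // arity=16: before=32767 bins, after=16 bins
  }
}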
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.BitSet
 import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -244,66 +245,7 @@ private[spark] object RandomForest extends Logging {
   }
 
-  /**
-   * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
-   *
-   * For ordered features, a single bin is updated.
-   * For unordered features, bins correspond to subsets of categories; either the left or right
-   * bin for each subset is updated.
-   *
-   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
-   *             each (feature, bin).
-   * @param treePoint  Data point being aggregated.
-   * @param splits  possible splits indexed (numFeatures)(numSplits)
-   * @param unorderedFeatures  Set of indices of unordered features.
-   * @param instanceWeight  Weight (importance) of instance in dataset.
-   */
-  private def mixedBinSeqOp(
-      agg: DTStatsAggregator,
-      treePoint: TreePoint,
-      splits: Array[Array[Split]],
-      unorderedFeatures: Set[Int],
-      instanceWeight: Double,
-      featuresForNode: Option[Array[Int]]): Unit = {
-    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
-      // Use subsampled features
-      featuresForNode.get.length
-    } else {
-      // Use all features
-      agg.metadata.numFeatures
-    }
-    // Iterate over features.
-    var featureIndexIdx = 0
-    while (featureIndexIdx < numFeaturesPerNode) {
-      val featureIndex = if (featuresForNode.nonEmpty) {
-        featuresForNode.get.apply(featureIndexIdx)
-      } else {
-        featureIndexIdx
-      }
-      if (unorderedFeatures.contains(featureIndex)) {
-        // Unordered feature
-        val featureValue = treePoint.binnedFeatures(featureIndex)
-        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
-        // Update the left or right bin for each split.
-        val numSplits = agg.metadata.numSplits(featureIndex)
-        val featureSplits = splits(featureIndex)
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
-            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
-          }
-          splitIndex += 1
-        }
-      } else {
-        // Ordered feature
-        val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
-      }
-      featureIndexIdx += 1
-    }
-  }
-
   /**
-   * Helper for binSeqOp, for regression and for classification with only ordered features.
+   * Helper for binSeqOp, for regression and for classification.
    *
    * For each feature, the sufficient statistics of one bin are updated.
    *
@@ -312,7 +254,7 @@ private[spark] object RandomForest extends Logging {
    * @param treePoint  Data point being aggregated.
    * @param instanceWeight  Weight (importance) of instance in dataset.
    */
-  private def orderedBinSeqOp(
+  private def _binSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
       instanceWeight: Double,
@@ -424,12 +366,7 @@ private[spark] object RandomForest extends Logging {
         val aggNodeIndex = nodeInfo.nodeIndexInGroup
         val featuresForNode = nodeInfo.featureSubset
         val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
-        if (metadata.unorderedFeatures.isEmpty) {
-          orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
-        } else {
-          mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
-            metadata.unorderedFeatures, instanceWeight, featuresForNode)
-        }
+        _binSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
         agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
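Because unordered features now keep one bin per category, every instance updates exactly one bin per feature, the same access pattern as for ordered features; that is why the ordered/unordered branch above collapses into the single `_binSeqOp`. A minimal standalone sketch of such a unified update path, assuming a simplified flat layout rather than the real `DTStatsAggregator`:

// Minimal sketch of a flat (feature, bin) statistics array where every
// feature, ordered or unordered, gets exactly one bin update per instance.
class FlatStatsAggregator(numBins: Array[Int], statsSize: Int) {
  // Offset of each feature's first value in the flat array.
  private val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((acc, nb) => acc + nb * statsSize)
  private val stats = new Array[Double](featureOffsets.last)

  // binnedFeatures(f) is the category (unordered) or bin index (ordered).
  def update(binnedFeatures: Array[Int], label: Double, weight: Double): Unit = {
    var f = 0
    while (f < numBins.length) {
      val offset = featureOffsets(f) + binnedFeatures(f) * statsSize
      stats(offset + label.toInt) += weight // class counts as sufficient stats
      f += 1
    }
  }
}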
@@ -741,17 +678,43 @@ private[spark] object RandomForest extends Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { splitIndex =>
-              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats = binAggregates.getParentImpurityCalculator()
-                .subtract(leftChildStats)
-              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
-                leftChildStats, rightChildStats, binAggregates.metadata)
-              (splitIndex, gainAndImpurityStats)
-            }.maxBy(_._2.gain)
-          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+          val numBins = binAggregates.metadata.numBins(featureIndex)
+          val featureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+
+          val binStatsArray = Array.tabulate(numBins) { binIndex =>

[Contributor] Could you please add a comment explaining what this is? E.g.:
+            binAggregates.getImpurityCalculator(featureOffset, binIndex)
+          }
+          val parentStats = binAggregates.getParentImpurityCalculator()
+
+          var bestGain = Double.NegativeInfinity
+          var bestSet: BitSet = null
+          var bestLeftChildStats: ImpurityCalculator = null
+          var bestRightChildStats: ImpurityCalculator = null
+
+          traverseUnorderedSplits[ImpurityCalculator](numBins, null,

[Contributor] Could you please add a comment explaining what this does? E.g.:
+            (stats, binIndex) => {
+              val binStats = binStatsArray(binIndex)
+              if (stats == null) {
+                binStats
+              } else {
+                stats.copy.add(binStats)
+              }
+            },
+            (set, leftChildStats) => {
+              val rightChildStats = parentStats.copy.subtract(leftChildStats)
+              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+                leftChildStats, rightChildStats, binAggregates.metadata)
+              if (gainAndImpurityStats.gain > bestGain) {
+                bestGain = gainAndImpurityStats.gain
+                bestSet = set | new BitSet(numBins) // copy set

[Contributor] Why not use …

[Contributor, author] The class does not support …
+                bestLeftChildStats = leftChildStats
+                bestRightChildStats = rightChildStats
+              }
+            }
+          )
+          val bestSplit = new CategoricalSplit(featureIndex,
+            bestSet.iterator.map(_.toDouble).toArray, numBins)
+          (bestSplit, gainAndImpurityStats)
         } else {
           // Ordered categorical feature
           val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
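The subsets themselves are now enumerated at split-finding time: each candidate left set's statistics are merged from the per-category bins, and the right child is obtained as parent minus left. A toy enumeration of the candidate subsets, assuming a feature with 3 categories:

// Toy enumeration (not Spark code): candidate left-child subsets for an
// unordered feature with 3 categories. Only 2^(3-1) - 1 = 3 masks are needed,
// because a subset and its complement describe the same binary split.
val arity = 3
val numSplits = (1 << (arity - 1)) - 1
(1 to numSplits).foreach { mask =>
  val leftCategories = (0 until arity).filter(c => ((mask >> c) & 1) == 1)
  println(s"left child gets categories: ${leftCategories.mkString(", ")}")
}
// left child gets categories: 0
// left child gets categories: 1
// left child gets categories: 0, 1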
@@ -935,17 +898,8 @@ private[spark] object RandomForest extends Logging {
           metadata.setNumSplits(i, split.length)
           split
 
-        case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
-          // Unordered features
-          // 2^(maxFeatureValue - 1) - 1 combinations
-          val featureArity = metadata.featureArity(i)
-          Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
-            val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
-            new CategoricalSplit(i, categories.toArray, featureArity)
-          }
-
         case i if metadata.isCategorical(i) =>
-          // Ordered features
+          // Ordered and unordered features
           // Splits are constructed as needed during training.
           Array.empty[Split]
       }
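Since candidate unordered splits are no longer pre-built here, the winning subset is materialized as a `CategoricalSplit` just once per node, directly from the `BitSet` produced by the traversal. A small sketch of that conversion, assuming Spark's `org.apache.spark.util.collection.BitSet`:

import org.apache.spark.util.collection.BitSet

// Sketch: turn the best subset found by the traversal into the category list
// that CategoricalSplit expects (here categories 0 and 2 go left).
val numCategories = 4
val bestSet = new BitSet(numCategories)
bestSet.set(0)
bestSet.set(2)
val leftCategories: Array[Double] = bestSet.iterator.map(_.toDouble).toArray
// leftCategories: Array(0.0, 2.0)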
@@ -976,6 +930,44 @@ private[spark] object RandomForest extends Logging {
     categories
   }
 
+  private[tree] def traverseUnorderedSplits[T](

[Contributor] Could you please add a docstring for this method, since it's a bit complicated?

[Contributor] Also, does …
+      arity: Int,
+      zeroStats: T,
+      seqOp: (T, Int) => T,
+      finalizer: (BitSet, T) => Unit): Unit = {
+    assert(arity > 1)
+
+    // numSplits = (1 << arity - 1) - 1
+    val numSplits = DecisionTreeMetadata.numUnorderedSplits(arity)
+    val subSet: BitSet = new BitSet(arity)
+
+    // DFS traversal; binIndex ranges over [0, arity)
+    def dfs(binIndex: Int, combNumber: Int, stats: T): Unit = {
+      if (binIndex == arity) {
+        // recursion exits when binIndex == arity
+        if (combNumber > 0) {
+          // we have a valid unordered split, saved in subSet
+          finalizer(subSet, stats)
+        }
+      } else {
+        subSet.set(binIndex)
+        val leftChildCombNumber = combNumber + (1 << binIndex)
+        // pruning: combNumber only needs to satisfy 1 <= combNumber <= numSplits,

[Contributor] If I understand correctly, the check …

[Contributor, author] Yes. For example, "00101" and "11010" are equivalent splits, so we should traverse only one of them.
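In bitmask terms, a subset and its complement induce the same partition of the categories, so restricting the traversal to 1 <= combNumber <= numSplits, where numSplits = 2^(arity-1) - 1, keeps exactly one representative of each pair. A quick check of the example from the reply above, assuming arity = 5:

val arity = 5
val numSplits = (1 << (arity - 1)) - 1   // 15
val a = Integer.parseInt("00101", 2)     // 5: kept, since 5 <= 15
val b = Integer.parseInt("11010", 2)     // 26: pruned, since 26 > 15
assert(a + b == (1 << arity) - 1)        // complements: together they cover all 5 categories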
+        // and since combNumber increases monotonically as the recursion deepens,
+        // we can stop recursing once combNumber > numSplits
+        if (leftChildCombNumber <= numSplits) {
+          val leftChildStats = seqOp(stats, binIndex)
+          dfs(binIndex + 1, leftChildCombNumber, leftChildStats)
+        }
+        subSet.unset(binIndex)
+        dfs(binIndex + 1, combNumber, stats)
+      }
+    }
+
+    dfs(0, 0, zeroStats)
+  }
+
   /**
    * Find splits for a continuous feature
    * NOTE: Returned number of splits is set based on `featureSamples` and
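The DFS includes or excludes one category at a time and reuses the partial statistics accumulated along the current path, so each of the 2^(arity-1) - 1 subsets costs a single seqOp merge instead of being rebuilt from scratch. A minimal sketch of the calling contract with plain Int statistics, mirroring the unit test further down (traverseUnorderedSplits is private[tree], so a real call site must live in the same package):

// Count the visited subsets and check that the stats stay in sync with the set.
var visited = 0
RandomForest.traverseUnorderedSplits[Int](4, 0,
  (count, binIndex) => count + 1,        // seqOp: category binIndex joins the subset
  (set, count) => {
    visited += 1
    assert(set.cardinality() == count)   // stats mirror the BitSet contents
  })
assert(visited == (1 << 3) - 1)          // 2^(4-1) - 1 = 7 subsets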
@@ -249,28 +249,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metadata.isUnordered(featureIndex = 1))
     val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
     assert(splits.length === 2)
-    assert(splits(0).length === 3)
+    assert(splits(0).length === 0)
     assert(metadata.numSplits(0) === 3)
     assert(metadata.numBins(0) === 3)
     assert(metadata.numSplits(1) === 3)
     assert(metadata.numBins(1) === 3)
-
-    // Expecting 2^2 - 1 = 3 splits per feature
-    def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
-      assert(s.featureIndex === featureIndex)
-      assert(s.isInstanceOf[CategoricalSplit])
-      val s0 = s.asInstanceOf[CategoricalSplit]
-      assert(s0.leftCategories === leftCategories)
-      assert(s0.numCategories === 3)  // for this unit test
-    }
-    // Feature 0
-    checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
-    checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
-    checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
-    // Feature 1
-    checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
-    checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
-    checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
   }
 
   test("Multiclass classification with ordered categorical features: split calculations") {
@@ -631,6 +614,42 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
| test("traverseUnorderedSplits") { | ||
|
|
||
|

[Contributor] Since …

[Contributor, author] So how can we test all possible splits to make sure the generated splits are all correct? If the tree is generated, only the best split is retained.
+    val numBins = 8
+    val numSplits = DecisionTreeMetadata.numUnorderedSplits(numBins)
+
+    val resultCheck = Array.fill(numSplits + 1)(false)
+
+    RandomForest.traverseUnorderedSplits[Int](numBins, 0,
+      (statsVal, binIndex) => statsVal + (1 << binIndex),
+      (bitSet, statsVal) => {
+        // We get a combination here; the bitSet marks as true exactly the bits
+        // that are in the combination, and statsVal is the combNumber:
+        // e.g., for the combination [0,0,1,0,1,1,0,1] (binIndex from high to low),
+        // statsVal is the number whose binary representation is "00101101".
+
+        // 1. check that the combination is not traversed more than once
+        assert(resultCheck(statsVal) === false)
+        resultCheck(statsVal) = true
+
+        // 2. check that the combNumber we get is correct,
+        // e.g., combNumber "00101101" (binary format) matches the combination
+        // stored in the bitSet [0,0,1,0,1,1,0,1]
+        for (i <- 0 until numBins) {
+          val testBit = (((statsVal >> i) & 1) == 1)
+          assert(bitSet.get(i) === testBit)
+        }
+      }
+    )
+    // 3. check that the traversal covers all combinations
+    // (total number of combinations = numSplits)
+    for (i <- 1 to numSplits) {
+      assert(resultCheck(i) === true)
+    }
+  }
 }
 
 private object RandomForestSuite {

[Contributor] Change "," to "."