diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index 8a9dcb486b7b..aa5a71ed19b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -73,7 +73,7 @@ private[spark] class DecisionTreeMetadata(
    * For ordered features, there is 1 more bin than split.
    */
   def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex)
+    DecisionTreeMetadata.numUnorderedSplits(numBins(featureIndex))
   } else {
     numBins(featureIndex) - 1
   }
@@ -152,15 +152,13 @@ private[spark] object DecisionTreeMetadata extends Logging {
         // TODO(SPARK-9957): Handle this properly by filtering out those features.
         if (numCategories > 1) {
           // Decide if some categorical features should be treated as unordered features,
-          //   which require 2 * ((1 << numCategories - 1) - 1) bins.
+          //   which, like ordered features, now require only numCategories bins.
           // We do this check with log values to prevent overflows in case numCategories is large.
           // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
           if (numCategories <= maxCategoriesForUnorderedFeature) {
             unorderedFeatures.add(featureIndex)
-            numBins(featureIndex) = numUnorderedBins(numCategories)
-          } else {
-            numBins(featureIndex) = numCategories
           }
+          numBins(featureIndex) = numCategories
         }
       }
     } else {
@@ -226,8 +224,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
-   * return the number of bins for the feature if it is to be treated as an unordered feature.
+   * return the number of splits for the feature if it is to be treated as an unordered feature.
    * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
    * there are math.pow(2, arity - 1) - 1 such splits.
-   * Each split has 2 corresponding bins.
    */
-  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
+  def numUnorderedSplits(arity: Int): Int = (1 << arity - 1) - 1
 }
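For intuition about the arithmetic above: a categorical feature of arity k treated as unordered still has 2^(k-1) - 1 distinct splits (each split pairs a non-empty left subset of categories with its complement, and a partition and its mirror image are the same split), but after this change it needs only k bins, one per category, instead of one bin pair per split. A minimal, self-contained Scala sketch of that count (the object and method names here are illustrative, not part of the patch):

  object UnorderedSplitCount {
    // Enumerate the left-child subsets of {0, ..., arity - 1}. Ranging the bit
    // mask over [1, 2^(arity - 1)) keeps the highest category on the right side,
    // so each partition is generated exactly once.
    def enumerate(arity: Int): Seq[Set[Int]] =
      (1 until (1 << (arity - 1))).map { mask =>
        (0 until arity).filter(i => ((mask >> i) & 1) == 1).toSet
      }

    def main(args: Array[String]): Unit = {
      val arity = 4
      val subsets = enumerate(arity)
      assert(subsets.length == (1 << (arity - 1)) - 1)  // 7 splits, but only 4 bins
      subsets.foreach(println)  // Set(0), Set(1), Set(0, 1), Set(2), ...
    }
  }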
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index acfc6399c553..36aec9cb6fcb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.collection.BitSet
 import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
 
@@ -244,66 +245,7 @@ private[spark] object RandomForest extends Logging {
   }
 
   /**
-   * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
-   *
-   * For ordered features, a single bin is updated.
-   * For unordered features, bins correspond to subsets of categories; either the left or right bin
-   * for each subset is updated.
-   *
-   * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
-   *             each (feature, bin).
-   * @param treePoint  Data point being aggregated.
-   * @param splits possible splits indexed (numFeatures)(numSplits)
-   * @param unorderedFeatures  Set of indices of unordered features.
-   * @param instanceWeight  Weight (importance) of instance in dataset.
-   */
-  private def mixedBinSeqOp(
-      agg: DTStatsAggregator,
-      treePoint: TreePoint,
-      splits: Array[Array[Split]],
-      unorderedFeatures: Set[Int],
-      instanceWeight: Double,
-      featuresForNode: Option[Array[Int]]): Unit = {
-    val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
-      // Use subsampled features
-      featuresForNode.get.length
-    } else {
-      // Use all features
-      agg.metadata.numFeatures
-    }
-    // Iterate over features.
-    var featureIndexIdx = 0
-    while (featureIndexIdx < numFeaturesPerNode) {
-      val featureIndex = if (featuresForNode.nonEmpty) {
-        featuresForNode.get.apply(featureIndexIdx)
-      } else {
-        featureIndexIdx
-      }
-      if (unorderedFeatures.contains(featureIndex)) {
-        // Unordered feature
-        val featureValue = treePoint.binnedFeatures(featureIndex)
-        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
-        // Update the left or right bin for each split.
-        val numSplits = agg.metadata.numSplits(featureIndex)
-        val featureSplits = splits(featureIndex)
-        var splitIndex = 0
-        while (splitIndex < numSplits) {
-          if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
-            agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
-          }
-          splitIndex += 1
-        }
-      } else {
-        // Ordered feature
-        val binIndex = treePoint.binnedFeatures(featureIndex)
-        agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
-      }
-      featureIndexIdx += 1
-    }
-  }
-
-  /**
-   * Helper for binSeqOp, for regression and for classification with only ordered features.
+   * Helper for binSeqOp, for regression and for classification.
    *
    * For each feature, the sufficient statistics of one bin are updated.
    *
@@ -312,7 +254,7 @@ private[spark] object RandomForest extends Logging {
    * @param treePoint  Data point being aggregated.
    * @param instanceWeight  Weight (importance) of instance in dataset.
    */
-  private def orderedBinSeqOp(
+  private def _binSeqOp(
       agg: DTStatsAggregator,
       treePoint: TreePoint,
       instanceWeight: Double,
       featuresForNode: Option[Array[Int]]): Unit = {
@@ -424,12 +366,7 @@ private[spark] object RandomForest extends Logging {
           val aggNodeIndex = nodeInfo.nodeIndexInGroup
           val featuresForNode = nodeInfo.featureSubset
           val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
-          if (metadata.unorderedFeatures.isEmpty) {
-            orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
-          } else {
-            mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
-              metadata.unorderedFeatures, instanceWeight, featuresForNode)
-          }
+          _binSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
           agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
         }
       }
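The reason a single ordered-style update per feature now suffices: with one bin per category, the left-child statistics of any candidate subset can be rebuilt at split-evaluation time by summing the member category bins, so aggregation no longer needs per-split bins. A toy sketch of that identity, with plain Doubles standing in for DTStatsAggregator's sufficient statistics (all names are illustrative):

  // Plain Doubles stand in for per-bin sufficient statistics.
  def leftChildStat(binStats: Array[Double], leftCategories: Set[Int]): Double =
    leftCategories.iterator.map(binStats(_)).sum

  val binStats = Array(2.0, 5.0, 3.0)            // one bin per category, arity = 3
  val left = leftChildStat(binStats, Set(0, 2))  // 2.0 + 3.0 = 5.0
  val right = binStats.sum - left                // parent total minus left = 5.0
  assert(left == 5.0 && right == 5.0)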
@@ -741,17 +678,43 @@
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
-          val (bestFeatureSplitIndex, bestFeatureGainStats) =
-            Range(0, numSplits).map { splitIndex =>
-              val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats = binAggregates.getParentImpurityCalculator()
-                .subtract(leftChildStats)
-              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
-                leftChildStats, rightChildStats, binAggregates.metadata)
-              (splitIndex, gainAndImpurityStats)
-            }.maxBy(_._2.gain)
-          (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
+          val numBins = binAggregates.metadata.numBins(featureIndex)
+          val featureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
+
+          val binStatsArray = Array.tabulate(numBins) { binIndex =>
+            binAggregates.getImpurityCalculator(featureOffset, binIndex)
+          }
+          val parentStats = binAggregates.getParentImpurityCalculator()
+
+          var bestGain = Double.NegativeInfinity
+          var bestSet: BitSet = null
+          var bestLeftChildStats: ImpurityCalculator = null
+          var bestRightChildStats: ImpurityCalculator = null
+
+          traverseUnorderedSplits[ImpurityCalculator](numBins, null,
+            (stats, binIndex) => {
+              // Accumulate stats for the categories chosen so far; the root call
+              // passes null, so the first selected bin seeds the accumulator.
+              val binStats = binStatsArray(binIndex)
+              if (stats == null) {
+                binStats
+              } else {
+                stats.copy.add(binStats)
+              }
+            },
+            (set, leftChildStats) => {
+              val rightChildStats = parentStats.copy.subtract(leftChildStats)
+              gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+                leftChildStats, rightChildStats, binAggregates.metadata)
+              if (gainAndImpurityStats.gain > bestGain) {
+                bestGain = gainAndImpurityStats.gain
+                bestSet = set | new BitSet(numBins)  // union with an empty set copies `set`
+                bestLeftChildStats = leftChildStats
+                bestRightChildStats = rightChildStats
+              }
+            }
+          )
+          // Recompute the impurity stats for the best split: gainAndImpurityStats
+          // currently holds the stats of the last split visited, not the best one.
+          gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+            bestLeftChildStats, bestRightChildStats, binAggregates.metadata)
+          val bestSplit = new CategoricalSplit(featureIndex,
+            bestSet.iterator.map(_.toDouble).toArray, numBins)
+          (bestSplit, gainAndImpurityStats)
         } else {
           // Ordered categorical feature
           val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
@@ -935,17 +898,8 @@ private[spark] object RandomForest extends Logging {
           metadata.setNumSplits(i, split.length)
           split
 
-        case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
-          // Unordered features
-          // 2^(maxFeatureValue - 1) - 1 combinations
-          val featureArity = metadata.featureArity(i)
-          Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
-            val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
-            new CategoricalSplit(i, categories.toArray, featureArity)
-          }
-
         case i if metadata.isCategorical(i) =>
-          // Ordered features
+          // Ordered and unordered features
           //   Splits are constructed as needed during training.
           Array.empty[Split]
       }
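How the traversal used above (and defined in the next hunk) enumerates candidate subsets: a depth-first walk over bin indices in which each left subset is identified by the bit mask combNumber, and pruning at combNumber <= numSplits skips the mirror image of every partition. A self-contained sketch of just the traversal contract, using standard-library sets instead of Spark's BitSet (illustrative only, not the patch's code):

  def traverseSketch(arity: Int)(visit: Set[Int] => Unit): Unit = {
    val numSplits = (1 << (arity - 1)) - 1
    def dfs(bin: Int, comb: Int, subset: Set[Int]): Unit = {
      if (bin == arity) {
        if (comb > 0) visit(subset)  // skip the empty subset
      } else {
        val withBin = comb + (1 << bin)
        if (withBin <= numSplits) {  // prune: complements would have comb > numSplits
          dfs(bin + 1, withBin, subset + bin)  // send this bin to the left child
        }
        dfs(bin + 1, comb, subset)  // keep this bin on the right
      }
    }
    dfs(0, 0, Set.empty)
  }

  traverseSketch(3)(println)  // Set(0, 1), Set(0), Set(1): one side of each partition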
@@ -976,6 +930,44 @@
     categories
   }
 
+  /**
+   * Depth-first traversal of all unordered splits of a feature with the given arity.
+   * Each subset of bin indices (the left child) is visited exactly once; the
+   * complementary subset is skipped, since it describes the same partition.
+   *
+   * @param arity  Number of categories (= number of bins) of the feature.
+   * @param zeroStats  Initial value of the statistics threaded through the traversal.
+   * @param seqOp  Combines the statistics accumulated so far with one more bin.
+   * @param finalizer  Called once per split with the left-child subset and its statistics.
+   */
+  private[tree] def traverseUnorderedSplits[T](
+      arity: Int,
+      zeroStats: T,
+      seqOp: (T, Int) => T,
+      finalizer: (BitSet, T) => Unit): Unit = {
+    assert(arity > 1)
+
+    // numSplits = (1 << arity - 1) - 1
+    val numSplits = DecisionTreeMetadata.numUnorderedSplits(arity)
+    val subSet: BitSet = new BitSet(arity)
+
+    // DFS over bin indices in [0, arity); combNumber encodes the current subset
+    // as a bit mask (bit i is set iff bin i goes to the left child).
+    def dfs(binIndex: Int, combNumber: Int, stats: T): Unit = {
+      if (binIndex == arity) {
+        if (combNumber > 0) {
+          // Reached a complete, non-empty unordered split, saved in subSet.
+          finalizer(subSet, stats)
+        }
+      } else {
+        subSet.set(binIndex)
+        val leftChildCombNumber = combNumber + (1 << binIndex)
+        // Pruning: we only need subsets whose combNumber satisfies
+        // 1 <= combNumber <= numSplits. combNumber increases monotonically as the
+        // recursion deepens, so the branch can be cut off as soon as it would
+        // exceed numSplits; this is what skips the complement of each partition.
+        if (leftChildCombNumber <= numSplits) {
+          val leftChildStats = seqOp(stats, binIndex)
+          dfs(binIndex + 1, leftChildCombNumber, leftChildStats)
+        }
+        subSet.unset(binIndex)
+        dfs(binIndex + 1, combNumber, stats)
+      }
+    }
+
+    dfs(0, 0, zeroStats)
+  }
+
   /**
    * Find splits for a continuous feature
    * NOTE: Returned number of splits is set based on `featureSamples` and
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dbe2ea931fb9..77e9d11259fc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -249,28 +249,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metadata.isUnordered(featureIndex = 1))
     val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
     assert(splits.length === 2)
-    assert(splits(0).length === 3)
+    assert(splits(0).length === 0)
     assert(metadata.numSplits(0) === 3)
     assert(metadata.numBins(0) === 3)
     assert(metadata.numSplits(1) === 3)
     assert(metadata.numBins(1) === 3)
-
-    // Expecting 2^2 - 1 = 3 splits per feature
-    def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
-      assert(s.featureIndex === featureIndex)
-      assert(s.isInstanceOf[CategoricalSplit])
-      val s0 = s.asInstanceOf[CategoricalSplit]
-      assert(s0.leftCategories === leftCategories)
-      assert(s0.numCategories === 3)  // for this unit test
-    }
-    // Feature 0
-    checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
-    checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
-    checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
-    // Feature 1
-    checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
-    checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
-    checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
   }
 
   test("Multiclass classification with ordered categorical features: split calculations") {
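The new test below relies on a left subset and its combination number being two views of the same bit mask. A small worked example of that encoding (the values are illustrative):

  val subset = Set(0, 2, 3)                // bin indices sent to the left child
  val combNumber = subset.map(1 << _).sum  // 1 + 4 + 8 = 13, binary "1101"
  val decoded = (0 until 8).filter(i => ((combNumber >> i) & 1) == 1).toSet
  assert(decoded == subset)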
@@ -631,6 +614,42 @@
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
+
+  test("traverseUnorderedSplits") {
+    val numBins = 8
+    val numSplits = DecisionTreeMetadata.numUnorderedSplits(numBins)
+
+    // One flag per combination number; index 0 (the empty subset) stays unused.
+    val resultCheck = Array.fill(numSplits + 1)(false)
+
+    RandomForest.traverseUnorderedSplits[Int](numBins, 0,
+      (statsVal, binIndex) => statsVal + (1 << binIndex),
+      (bitSet, statsVal) => {
+        // Each call receives one combination: bitSet has a bit set for every
+        // category in the left subset, and statsVal is the combNumber built up
+        // by seqOp. E.g., for the combination [0,0,1,0,1,1,0,1] (binIndex from
+        // high to low), statsVal is the number with binary representation
+        // "00101101".
+
+        // 1. Check that no combination is visited more than once.
+        assert(resultCheck(statsVal) === false)
+        resultCheck(statsVal) = true
+
+        // 2. Check that the combNumber is correct, i.e. that the bits of
+        //    statsVal ("00101101" above) match the combination stored in bitSet.
+        for (i <- 0 until numBins) {
+          val testBit = (((statsVal >> i) & 1) == 1)
+          assert(bitSet.get(i) === testBit)
+        }
+      }
+    )
+    // 3. Check that the traversal covers all combinations (numSplits in total).
+    for (i <- 1 to numSplits) {
+      assert(resultCheck(i) === true)
+    }
+  }
 }
 
 private object RandomForestSuite {