@@ -73,7 +73,7 @@ private[spark] class DecisionTreeMetadata(
* For ordered features, there is 1 more bin than split.
*/
def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
numBins(featureIndex)
DecisionTreeMetadata.numUnorderedSplits(numBins(featureIndex))
} else {
numBins(featureIndex) - 1
}
@@ -152,15 +152,13 @@ private[spark] object DecisionTreeMetadata extends Logging {
// TODO(SPARK-9957): Handle this properly by filtering out those features.
if (numCategories > 1) {
// Decide if some categorical features should be treated as unordered features,
Contributor: Change , to .

// which require 2 * ((1 << numCategories - 1) - 1) bins.
// Both ordered and unordered features require numCategories bins.
// We do this check with log values to prevent overflows in case numCategories is large.
// The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
if (numCategories <= maxCategoriesForUnorderedFeature) {
unorderedFeatures.add(featureIndex)
numBins(featureIndex) = numUnorderedBins(numCategories)
} else {
numBins(featureIndex) = numCategories
}
numBins(featureIndex) = numCategories
}
}
} else {
@@ -226,8 +224,7 @@ private[spark] object DecisionTreeMetadata extends Logging {
* return the number of bins for the feature if it is to be treated as an unordered feature.
* There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets;
* there are math.pow(2, arity - 1) - 1 such splits.
* Each split has 2 corresponding bins.
*/
def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
def numUnorderedSplits(arity: Int): Int = (1 << arity - 1) - 1

}
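The closed form above is easy to sanity-check: an unordered split is a partition of the categories into two disjoint, non-empty sets, and each subset and its complement describe the same split, so fixing one distinguished category in the left set counts each partition exactly once. A minimal standalone sketch (my own illustration, not part of the patch) comparing the closed form against a brute-force count:

  // Standalone sanity check for the closed form. Note that `1 << arity - 1`
  // parses as `1 << (arity - 1)` in Scala, since `-` binds tighter than `<<`.
  object UnorderedSplitCountCheck {
    def numUnorderedSplits(arity: Int): Int = (1 << arity - 1) - 1

    // Brute force: enumerate all non-empty proper subsets of {0, ..., arity-1}
    // and keep one representative per complementary pair by requiring that
    // the subset contain category 0.
    def bruteForceCount(arity: Int): Int =
      (1 until (1 << arity) - 1).count(mask => (mask & 1) == 1)

    def main(args: Array[String]): Unit = {
      (2 to 10).foreach { arity =>
        assert(numUnorderedSplits(arity) == bruteForceCount(arity))
      }
      // e.g. arity = 3 gives 3 splits: {0} | {1,2}, {1} | {0,2}, {0,1} | {2}
      println((2 to 10).map(numUnorderedSplits)) // Vector(1, 3, 7, 15, ...)
    }
  }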
164 changes: 78 additions & 86 deletions mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.collection.BitSet
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}


@@ -244,66 +245,7 @@ private[spark] object RandomForest extends Logging {
}

/**
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
*
* For ordered features, a single bin is updated.
* For unordered features, bins correspond to subsets of categories; either the left or right bin
* for each subset is updated.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
* @param splits possible splits indexed (numFeatures)(numSplits)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val numFeaturesPerNode = if (featuresForNode.nonEmpty) {
// Use subsampled features
featuresForNode.get.length
} else {
// Use all features
agg.metadata.numFeatures
}
// Iterate over features.
var featureIndexIdx = 0
while (featureIndexIdx < numFeaturesPerNode) {
val featureIndex = if (featuresForNode.nonEmpty) {
featuresForNode.get.apply(featureIndexIdx)
} else {
featureIndexIdx
}
if (unorderedFeatures.contains(featureIndex)) {
// Unordered feature
val featureValue = treePoint.binnedFeatures(featureIndex)
val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
// Update the left or right bin for each split.
val numSplits = agg.metadata.numSplits(featureIndex)
val featureSplits = splits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
}
splitIndex += 1
}
} else {
// Ordered feature
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight)
}
featureIndexIdx += 1
}
}

/**
* Helper for binSeqOp, for regression and for classification with only ordered features.
* Helper for binSeqOp, for regression and for classification.
*
* For each feature, the sufficient statistics of one bin are updated.
*
@@ -312,7 +254,7 @@ private[spark] object RandomForest extends Logging {
* @param treePoint Data point being aggregated.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def orderedBinSeqOp(
private def _binSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
instanceWeight: Double,
@@ -424,12 +366,7 @@ private[spark] object RandomForest extends Logging {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
_binSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
}
}
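To make "the sufficient statistics of one bin are updated" concrete, here is a toy stand-in for DTStatsAggregator (my own sketch; Point, update, and the tuple-grid representation are all hypothetical simplifications): statistics form a (feature, bin) grid of (weighted count, weighted label sum), and one data point touches exactly one bin per feature.

  object BinSeqOpSketch {
    final case class Point(binnedFeatures: Array[Int], label: Double)

    def update(stats: Array[Array[(Double, Double)]],
               p: Point, instanceWeight: Double): Unit = {
      var f = 0
      while (f < p.binnedFeatures.length) {
        val bin = p.binnedFeatures(f) // each feature updates a single bin
        val (w, s) = stats(f)(bin)
        stats(f)(bin) = (w + instanceWeight, s + instanceWeight * p.label)
        f += 1
      }
    }

    def main(args: Array[String]): Unit = {
      val stats = Array.fill(2, 3)((0.0, 0.0)) // 2 features, 3 bins each
      update(stats, Point(Array(0, 2), label = 1.0), instanceWeight = 1.0)
      update(stats, Point(Array(0, 1), label = 0.0), instanceWeight = 2.0)
      println(stats(0)(0)) // weight 3.0, weighted label sum 1.0
    }
  }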
@@ -741,17 +678,43 @@ private[spark] object RandomForest extends Logging {
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
// Unordered categorical feature
val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
val (bestFeatureSplitIndex, bestFeatureGainStats) =
Range(0, numSplits).map { splitIndex =>
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats = binAggregates.getParentImpurityCalculator()
.subtract(leftChildStats)
val numBins = binAggregates.metadata.numBins(featureIndex)
val featureOffset = binAggregates.getFeatureOffset(featureIndexIdx)

val binStatsArray = Array.tabulate(numBins) { binIndex =>
Contributor: Could you please add a comment explaining what this is? E.g.:
// Each element of binStatsArray stores pre-computed label statistics for a single bin of the current feature

binAggregates.getImpurityCalculator(featureOffset, binIndex)
}
val parentStats = binAggregates.getParentImpurityCalculator()

var bestGain = Double.NegativeInfinity
var bestSet: BitSet = null
var bestLeftChildStats: ImpurityCalculator = null
var bestRightChildStats: ImpurityCalculator = null

traverseUnorderedSplits[ImpurityCalculator](numBins, null,
Contributor: Could you please add a comment explaining what this does? E.g.:
// Computes the best split for the current feature, storing the result across the vars above

(stats, binIndex) => {
val binStats = binStatsArray(binIndex)
if (stats == null) {
binStats
} else {
stats.copy.add(binStats)
}
},
(set, leftChildStats) => {
val rightChildStats = parentStats.copy.subtract(leftChildStats)
gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
leftChildStats, rightChildStats, binAggregates.metadata)
(splitIndex, gainAndImpurityStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
if (gainAndImpurityStats.gain > bestGain) {
bestGain = gainAndImpurityStats.gain
bestSet = set | new BitSet(numBins) // copy set
Contributor: Why not use set.copy()?

Contributor (author): The class does not support copy; see the BitSet sketch after this hunk.

bestLeftChildStats = leftChildStats
bestRightChildStats = rightChildStats
}
}
)
val bestSplit = new CategoricalSplit(featureIndex,
bestSet.iterator.map(_.toDouble).toArray, numBins)
(bestSplit, gainAndImpurityStats)
} else {
// Ordered categorical feature
val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
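Regarding the set.copy() thread above: Spark's org.apache.spark.util.collection.BitSet exposes union (|) but no copy method, so union with an empty set of the same capacity serves as a copy. A small sketch of the idiom (an illustration only, assuming spark-core on the classpath and access to Spark's internal collection package; the leftCategories line mirrors how the patch builds the CategoricalSplit):

  import org.apache.spark.util.collection.BitSet

  object BitSetCopyIdiom {
    def main(args: Array[String]): Unit = {
      val numBins = 8
      val set = new BitSet(numBins)
      set.set(1)
      set.set(3)

      // No copy() method; union with an empty set of the same capacity
      // allocates a fresh, identical instance.
      val snapshot = set | new BitSet(numBins)

      set.unset(3) // later mutations do not affect the snapshot
      assert(snapshot.get(3) && !set.get(3))

      // Bit indices become category values, as in the patch's CategoricalSplit.
      val leftCategories = snapshot.iterator.map(_.toDouble).toArray
      println(leftCategories.mkString(", ")) // 1.0, 3.0
    }
  }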
@@ -935,17 +898,8 @@ private[spark] object RandomForest extends Logging {
metadata.setNumSplits(i, split.length)
split

case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
// Unordered features
// 2^(maxFeatureValue - 1) - 1 combinations
val featureArity = metadata.featureArity(i)
Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
new CategoricalSplit(i, categories.toArray, featureArity)
}

case i if metadata.isCategorical(i) =>
// Ordered features
// Ordered and unordered features
// Splits are constructed as needed during training.
Array.empty[Split]
}
Expand Down Expand Up @@ -976,6 +930,44 @@ private[spark] object RandomForest extends Logging {
categories
}

private[tree] def traverseUnorderedSplits[T](
Contributor: Could you please add a docstring for this method, since it's a bit complicated?

Contributor: Also, does traverseUnorderedSplits need to take a type parameter / two different closures as method arguments? AFAICT the use of a type parameter/closures here allows us to unit test this functionality on a simple example, but I wonder if we could simplify this somehow.

arity: Int,
zeroStats: T,
seqOp: (T, Int) => T,
finalizer: (BitSet, T) => Unit): Unit = {
assert(arity > 1)

// numSplits = (1 << arity - 1) - 1
val numSplits = DecisionTreeMetadata.numUnorderedSplits(arity)
val subSet: BitSet = new BitSet(arity)

// dfs traverse
// binIndex: [0, arity)
def dfs(binIndex: Int, combNumber: Int, stats: T): Unit = {
if (binIndex == arity) {
// recursion exit when binIndex == arity
if (combNumber > 0) {
// we have reached a valid unordered split, saved in subSet.
finalizer(subSet, stats)
}
} else {
subSet.set(binIndex)
val leftChildCombNumber = combNumber + (1 << binIndex)
// pruning: we only need combNumber values satisfying 1 <= combNumber <= numSplits,
Contributor: If I understand correctly, the check if (leftChildCombNumber <= numSplits) helps us ensure that we consider each split only once, right?

Contributor (author): Yes. For example, "00101" and "11010" are equivalent splits, so we should traverse only one of them. Here I use the condition 1 <= combNumber <= numSplits to do the pruning; it simply filters out the other half of the splits.

// and as the recursion goes deeper, combNumber increases monotonically,
// so we can stop recursing once combNumber > numSplits
if (leftChildCombNumber <= numSplits) {
val leftChildStats = seqOp(stats, binIndex)
dfs(binIndex + 1, leftChildCombNumber, leftChildStats)
}
subSet.unset(binIndex)
dfs(binIndex + 1, combNumber, stats)
}
}

dfs(0, 0, zeroStats)
}
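To make the pruning in dfs concrete, here is a standalone sketch (plain Scala, no Spark types, my own illustration) that enumerates the same subsets on Int bitmasks. A subset and its complement (e.g. "00101" and "11010" for arity 5) describe the same split; because combNumber grows monotonically down the recursion, rejecting combNumber > numSplits keeps exactly one representative of each complementary pair:

  object TraverseUnorderedSketch {
    def enumerate(arity: Int): Seq[Set[Int]] = {
      val numSplits = (1 << arity - 1) - 1
      val found = scala.collection.mutable.Buffer.empty[Set[Int]]
      def dfs(binIndex: Int, combNumber: Int, chosen: Set[Int]): Unit = {
        if (binIndex == arity) {
          if (combNumber > 0) found += chosen // one split per non-empty subset
        } else {
          val withBin = combNumber + (1 << binIndex)
          // Same pruning as above: keep one representative per pair.
          if (withBin <= numSplits) dfs(binIndex + 1, withBin, chosen + binIndex)
          dfs(binIndex + 1, combNumber, chosen) // leave binIndex out
        }
      }
      dfs(0, 0, Set.empty)
      found.toSeq
    }

    def main(args: Array[String]): Unit = {
      // arity = 3 yields 3 splits: Set(0, 1), Set(0), Set(1)
      println(enumerate(3))
      assert(enumerate(8).size == (1 << 7) - 1) // matches numUnorderedSplits(8)
    }
  }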

/**
* Find splits for a continuous feature
* NOTE: Returned number of splits is set based on `featureSamples` and
@@ -249,28 +249,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(metadata.isUnordered(featureIndex = 1))
val splits = RandomForest.findSplits(rdd, metadata, seed = 42)
assert(splits.length === 2)
assert(splits(0).length === 3)
assert(splits(0).length === 0)
assert(metadata.numSplits(0) === 3)
assert(metadata.numBins(0) === 3)
assert(metadata.numSplits(1) === 3)
assert(metadata.numBins(1) === 3)

// Expecting 2^2 - 1 = 3 splits per feature
def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = {
assert(s.featureIndex === featureIndex)
assert(s.isInstanceOf[CategoricalSplit])
val s0 = s.asInstanceOf[CategoricalSplit]
assert(s0.leftCategories === leftCategories)
assert(s0.numCategories === 3) // for this unit test
}
// Feature 0
checkCategoricalSplit(splits(0)(0), 0, Array(0.0))
checkCategoricalSplit(splits(0)(1), 0, Array(1.0))
checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0))
// Feature 1
checkCategoricalSplit(splits(1)(0), 1, Array(0.0))
checkCategoricalSplit(splits(1)(1), 1, Array(1.0))
checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0))
}

test("Multiclass classification with ordered categorical features: split calculations") {
@@ -631,6 +614,42 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
}

test("traverseUnorderedSplits") {

Contributor: Since traverseUnorderedSplits is a private method, I wonder whether we can check the unordered splits on DecisionTree directly? For example, create a tiny dataset and generate a shallow tree (depth = 1?). I know such a test case is difficult (maybe impossible) to design, but it would focus on behavior instead of implementation.

Contributor (author): But how would we test all possible splits to make sure the generated splits are correct? Once a tree is generated, only the best split remains.

val numBins = 8
val numSplits = DecisionTreeMetadata.numUnorderedSplits(numBins)

val resultCheck = Array.fill(numSplits + 1)(false)

RandomForest.traverseUnorderedSplits[Int](numBins, 0,
(statsVal, binIndex) => statsVal + (1 << binIndex),
(bitSet, statsVal) => {
// We receive one combination here; the bitSet marks the bits that belong
// to the combination, and statsVal is its combNumber.
// E.g. for the combination [0,0,1,0,1,1,0,1] (binIndex from high to low),
// statsVal is the number whose binary representation is "00101101".

// 1. Check that each combination is traversed no more than once.
assert(resultCheck(statsVal) === false)
resultCheck(statsVal) = true

// 2. Check that the combNumber we get is correct, i.e. combNumber
// "00101101" (binary) matches the combination stored in the bitSet
// [0,0,1,0,1,1,0,1].
for (i <- 0 until numBins) {
val testBit = (((statsVal >> i) & 1) == 1)
assert(bitSet.get(i) === testBit)
}
}
)
// 3. Check that the traversal covers all combinations (numSplits in total).
for (i <- 1 to numSplits) {
assert(resultCheck(i) === true)
}
}
}
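As a quick illustration of the correspondence the assertions above verify, the same callback style at numBins = 3 (a hypothetical REPL-style snippet, assuming the suite's private[tree] access) visits three subsets, whose combNumbers are the integers with exactly those bits set:

  // Expected output (one line per visited subset; bits printed from binIndex 0 up):
  //   combNumber=3 bits=110   -> subset {0, 1}
  //   combNumber=1 bits=100   -> subset {0}
  //   combNumber=2 bits=010   -> subset {1}
  RandomForest.traverseUnorderedSplits[Int](3, 0,
    (statsVal, binIndex) => statsVal + (1 << binIndex),
    (bitSet, statsVal) => println(s"combNumber=$statsVal bits=" +
      (0 until 3).map(i => if (bitSet.get(i)) 1 else 0).mkString))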

private object RandomForestSuite {