diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index 07e98a142b10e..17aba54f21bb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -276,14 +276,10 @@ private[tree] class LearningNode( new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - if (stats.valid) { - new LeafNode(stats.impurityCalculator.predict, stats.impurity, - stats.impurityCalculator) - } else { - // Here we want to keep same behavior with the old mllib.DecisionTreeModel - new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) - } - + assert(stats != null, "Unknown error during Decision Tree learning. Could not convert " + + "LearningNode to Node") + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) } } @@ -334,7 +330,7 @@ private[tree] object LearningNode { id: Int, isLeaf: Boolean, stats: ImpurityStats): LearningNode = { - new LearningNode(id, None, None, None, false, stats) + new LearningNode(id, None, None, None, isLeaf, stats) } /** Create an empty node with the given node index. Values must be set later on. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala new file mode 100644 index 0000000000000..07e4a16e2990c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.Split + +/** + * Helpers for updating DTStatsAggregators during collection of sufficient stats for tree training. + */ +private[impl] object AggUpdateUtils { + + /** + * Updates the parent node stats of the passed-in impurity aggregator with the labels + * corresponding to the feature values at indices [from, to). 
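+   * For example (hypothetical values): with indices = Array(3, 0, 2), from = 1 and to = 3,
+   * the labels and weights of rows 0 and 2 are added to the aggregator's parent stats:
+   * {{{
+   *   AggUpdateUtils.updateParentImpurity(statsAggregator, Array(3, 0, 2), 1, 3, instanceWeights, labels)
+   * }}}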
+ * @param indices Array of row indices for feature values; indices(i) = row index of the ith + * feature value + */ + private[impl] def updateParentImpurity( + statsAggregator: DTStatsAggregator, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): Unit = { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + val label = labels(rowIndex) + statsAggregator.updateParent(label, instanceWeights(rowIndex)) + } + } + + /** + * Update aggregator for an (unordered feature, label) pair + * @param featureSplits Array of splits for the current feature + */ + private[impl] def updateUnorderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + instanceWeight: Double): Unit = { + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) + // Each unordered split has a corresponding bin for impurity stats of data points that fall + // onto the left side of the split. For each unordered split, update left-side bin if applicable + // for the current data point. + val numSplits = agg.metadata.numSplits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, label, instanceWeight) + } + splitIndex += 1 + } + } + + /** Update aggregator for an (ordered feature, label) pair */ + private[impl] def updateOrderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndexIdx: Int, + instanceWeight: Double): Unit = { + // The bin index of an ordered feature is just the feature value itself + val binIndex = featureValue + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/FeatureColumn.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/FeatureColumn.scala new file mode 100644 index 0000000000000..a403fd7e0bc40 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/FeatureColumn.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.util.collection.BitSet + +/** + * Stores values for a single training data column (a single continuous or categorical feature). + * + * Values are currently stored in a dense representation only. + * TODO: Support sparse storage (to optimize deeper levels of the tree), and maybe compressed + * storage (to optimize upper levels of the tree). + * + * TODO: Sort feature values to support more complicated splitting logic (e.g. 
considering every + * possible continuous split instead of discretizing continuous features). + * + * TODO: Consider sorting feature values; the only changed required would be to + * sort values at construction-time. Sorting might improve locality during stats + * aggregation (we'd frequently update the same O(statsSize) array for a (feature, bin), + * instead of frequently updating for the same feature). + * + */ +private[impl] class FeatureColumn( + val featureIndex: Int, + val values: Array[Int]) + extends Serializable { + + /** For debugging */ + override def toString: String = { + " FeatureVector(" + + s" featureIndex: $featureIndex,\n" + + s" values: ${values.mkString(", ")},\n" + + " )" + } + + def deepCopy(): FeatureColumn = new FeatureColumn(featureIndex, values.clone()) + + override def equals(other: Any): Boolean = { + other match { + case o: FeatureColumn => + featureIndex == o.featureIndex && values.sameElements(o.values) + case _ => false + } + } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + featureIndex: java.lang.Integer, + values) + } + + /** + * Reorders the subset of feature values at indices [from, to) in the passed-in column + * according to the split information encoded in instanceBitVector (feature values for rows + * that split left appear before feature values for rows that split right). + * + * @param numLeftRows Number of rows on the left side of the split + * @param tempVals Destination buffer for reordered feature values + * @param instanceBitVector instanceBitVector(i) = true if the row for the (from + i)th feature + * value splits right, false otherwise + */ + private[ml] def updateForSplit( + from: Int, + to: Int, + numLeftRows: Int, + tempVals: Array[Int], + instanceBitVector: BitSet): Unit = { + LocalDecisionTreeUtils.updateArrayForSplit(values, from, to, numLeftRows, tempVals, + instanceBitVector) + } +} + +private[impl] object FeatureColumn { + /** + * Store column values sorted by decision tree node (i.e. all column values for a node occur + * in a contiguous subarray). + */ + private[impl] def apply(featureIndex: Int, values: Array[Int]) = { + new FeatureColumn(featureIndex, values) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala new file mode 100644 index 0000000000000..0dd021eec2473 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Helper methods for impurity-related calculations during node split decisions. */ +private[impl] object ImpurityUtils { + + /** + * Get impurity calculator containing statistics for all labels for rows corresponding to + * feature values in [from, to). + * @param indices indices(i) = row index corresponding to ith feature value + */ + private[impl] def getParentImpurityCalculator( + metadata: DecisionTreeMetadata, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): ImpurityCalculator = { + // Compute sufficient stats (e.g. label counts) for all data at the current node, + // store result in currNodeStatsAgg.parentStats so that we can share it across + // all features for the current node + val currNodeStatsAgg = new DTStatsAggregator(metadata, featureSubset = None) + AggUpdateUtils.updateParentImpurity(currNodeStatsAgg, indices, from, to, + instanceWeights, labels) + currNodeStatsAgg.getParentImpurityCalculator() + } + + /** + * Calculate the impurity statistics for a given (feature, split) based upon left/right + * aggregates. + * + * @param parentImpurityCalculator An ImpurityCalculator containing the impurity stats + * of the node currently being split. + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) + */ + private[impl] def calculateImpurityStats( + parentImpurityCalculator: ImpurityCalculator, + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata): ImpurityStats = { + + val impurity: Double = parentImpurityCalculator.calculate() + + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count + + val totalCount = leftCount + rightCount + + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() + + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + // If information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. + if (gain < metadata.minInfoGain) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + // If information gain is non-positive but doesn't violate the minimum info gain constraint, + // return a stats object with correct values but valid = false to indicate that we should not + // split. 
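+    // (Worked example of a zero-gain split: a parent with impurity 0.5 split evenly into two
+    // children that each also have impurity 0.5 gives gain = 0.5 - 0.5 * 0.5 - 0.5 * 0.5 = 0.)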
+ if (gain <= 0) { + return new ImpurityStats(gain, impurity, parentImpurityCalculator, leftImpurityCalculator, + rightImpurityCalculator, valid = false) + } + + + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) + } + + /** + * Given an impurity aggregator containing label statistics for a given (node, feature, bin), + * returns the corresponding "centroid", used to order bins while computing best splits. + * + * @param metadata learning and dataset metadata for DecisionTree + */ + private[impl] def getCentroid( + metadata: DecisionTreeMetadata, + binStats: ImpurityCalculator): Double = { + + if (binStats.count != 0) { + if (metadata.isMulticlass) { + // multiclass classification + // For categorical features in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + binStats.calculate() + } else if (metadata.isClassification) { + // binary classification + // For categorical features in binary classification, + // the bins are ordered by the count of class 1. + binStats.stats(1) + } else { + // regression + // For categorical features in regression and binary classification, + // the bins are ordered by the prediction. + binStats.predict + } + } else { + Double.MaxValue + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala new file mode 100644 index 0000000000000..eea5418cf8a38 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Object exposing methods for local training of decision trees */ +private[ml] object LocalDecisionTree { + + /** + * Fully splits the passed-in node on the provided local dataset, returning + * an InternalNode/LeafNode corresponding to the root of the resulting tree. + * + * @param node LearningNode to use as the root of the subtree fit on the passed-in dataset + * @param metadata learning and dataset metadata for DecisionTree + * @param splits splits(i) = array of splits for feature i + */ + private[ml] def fitNode( + input: Array[TreePoint], + instanceWeights: Array[Double], + node: LearningNode, + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): Node = { + + // The case with 1 node (depth = 0) is handled separately. + // This allows all iterations in the depth > 0 case to use the same code. + // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node (because of + // other parameters). 
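+    // With maxDepth == 0 the root cannot be split, so it is converted to a leaf immediately
+    // without building the column store.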
+ if (metadata.maxDepth == 0) { + return node.toNode + } + + // Prepare column store. + // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue. + val colStoreInit: Array[Array[Int]] + = LocalDecisionTreeUtils.rowToColumnStoreDense(input.map(_.binnedFeatures)) + val labels = input.map(_.label) + + // Fit a regression model on the dataset, throwing an error if metadata indicates that + // we should train a classifier. + // TODO: Add support for training classifiers + if (metadata.numClasses > 1 && metadata.numClasses <= 32) { + throw new UnsupportedOperationException("Local training of a decision tree classifier is " + + "unsupported; currently, only regression is supported") + } else { + trainRegressor(node, colStoreInit, instanceWeights, labels, metadata, splits) + } + } + + /** + * Locally fits a decision tree regressor. + * TODO(smurching): Logic for fitting a classifier & regressor is the same; only difference + * is impurity metric. Use the same logic for fitting a classifier. + * + * @param rootNode Node to use as root of the tree fit on the passed-in dataset + * @param colStoreInit Array of columns of training data + * @param instanceWeights Array of weights for each training example + * @param metadata learning and dataset metadata for DecisionTree + * @param splits splits(i) = Array of possible splits for feature i + * @return LeafNode or InternalNode representation of rootNode + */ + private[ml] def trainRegressor( + rootNode: LearningNode, + colStoreInit: Array[Array[Int]], + instanceWeights: Array[Double], + labels: Array[Double], + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): Node = { + + // Sort each column by decision tree node. + val colStore: Array[FeatureColumn] = colStoreInit.zipWithIndex.map { case (col, featureIndex) => + val featureArity: Int = metadata.featureArity.getOrElse(featureIndex, 0) + FeatureColumn(featureIndex, col) + } + + val numRows = colStore.headOption match { + case None => 0 + case Some(column) => column.values.length + } + + // Create a new TrainingInfo describing the status of our partially-trained subtree + // at each iteration of training + var trainingInfo: TrainingInfo = TrainingInfo(colStore, + nodeOffsets = Array[(Int, Int)]((0, numRows)), currentLevelActiveNodes = Array(rootNode)) + + // Iteratively learn, one level of the tree at a time. + // Note: We do not use node IDs. + var currentLevel = 0 + var doneLearning = false + + while (currentLevel < metadata.maxDepth && !doneLearning) { + // Splits each active node if possible, returning an array of new active nodes + val nextLevelNodes: Array[LearningNode] = + computeBestSplits(trainingInfo, instanceWeights, labels, metadata, splits) + // Count number of non-leaf nodes in the next level + val estimatedRemainingActive = nextLevelNodes.count(!_.isLeaf) + // TODO: Check to make sure we split something, and stop otherwise. + doneLearning = currentLevel + 1 >= metadata.maxDepth || estimatedRemainingActive == 0 + if (!doneLearning) { + // Obtain a new trainingInfo instance describing our current training status + trainingInfo = trainingInfo.update(splits, nextLevelNodes) + } + currentLevel += 1 + } + + // Done with learning + rootNode.toNode + } + + /** + * Iterate over feature values and labels for a specific (node, feature), updating stats + * aggregator for the current node. 
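+   * A usage sketch (hypothetical offsets): aggregate stats over the rows whose feature values
+   * sit at positions [2, 4) of the node-sorted column, treating the feature as index 0 within
+   * the node's feature subset:
+   * {{{
+   *   updateAggregator(statsAggregator, col, indices, instanceWeights, labels,
+   *     from = 2, to = 4, featureIndexIdx = 0, featureSplits = splits(col.featureIndex))
+   * }}}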
+ */ + private[impl] def updateAggregator( + statsAggregator: DTStatsAggregator, + col: FeatureColumn, + indices: Array[Int], + instanceWeights: Array[Double], + labels: Array[Double], + from: Int, + to: Int, + featureIndexIdx: Int, + featureSplits: Array[Split]): Unit = { + val metadata = statsAggregator.metadata + if (metadata.isUnordered(col.featureIndex)) { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + AggUpdateUtils.updateUnorderedFeature(statsAggregator, col.values(idx), labels(rowIndex), + featureIndex = col.featureIndex, featureIndexIdx, featureSplits, + instanceWeight = instanceWeights(rowIndex)) + } + } else { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + AggUpdateUtils.updateOrderedFeature(statsAggregator, col.values(idx), labels(rowIndex), + featureIndexIdx, instanceWeight = instanceWeights(rowIndex)) + } + } + } + + /** + * Find the best splits for all active nodes + * + * @param trainingInfo Contains node offset info for current set of active nodes + * @return Array of new active nodes formed by splitting the current set of active nodes. + */ + private def computeBestSplits( + trainingInfo: TrainingInfo, + instanceWeights: Array[Double], + labels: Array[Double], + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): Array[LearningNode] = { + // For each node, select the best split across all features + trainingInfo match { + case TrainingInfo(columns: Array[FeatureColumn], nodeOffsets: Array[(Int, Int)], + currentLevelActiveNodes: Array[LearningNode], _) => { + // Filter out leaf nodes from the previous iteration + val activeNonLeafs = currentLevelActiveNodes.zipWithIndex.filterNot(_._1.isLeaf) + // Iterate over the active nodes in the current level. + activeNonLeafs.flatMap { case (node: LearningNode, nodeIndex: Int) => + // Features for the current node start at fromOffset and end at toOffset + val (from, to) = nodeOffsets(nodeIndex) + // Get impurityCalculator containing label stats for all data points at the current node + val parentImpurityCalc = ImpurityUtils.getParentImpurityCalculator(metadata, + trainingInfo.indices, from, to, instanceWeights, labels) + val validFeatureSplits = RandomForest.getFeaturesWithSplits(metadata, + featuresForNode = None) + // Find the best split for each feature for the current node + val splitsAndImpurityInfo = validFeatureSplits.map { case (_, featureIndex) => + val col = columns(featureIndex) + // Create a DTStatsAggregator to hold label statistics for each bin of the current + // feature & compute said label statistics + val statsAggregator = new DTStatsAggregator(metadata, Some(Array(featureIndex))) + updateAggregator(statsAggregator, col, trainingInfo.indices, instanceWeights, + labels, from, to, featureIndexIdx = 0, splits(col.featureIndex)) + // Choose best split for current feature based on label statistics + SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndexIdx = 0, + splits(featureIndex), Some(parentImpurityCalc)) + } + // Find the best split overall (across all features) for the current node + val (bestSplit, bestStats) = RandomForest.getBestSplitByGain(parentImpurityCalc, metadata, + featuresForNode = None, splitsAndImpurityInfo) + // Split current node, get an iterator over its children + splitIfPossible(node, metadata, bestStats, bestSplit) + } + } + } + } + + /** + * Splits the passed-in node if permitted by the parameters of the learning algorithm, + * returning an iterator over its children. Returns an empty array if node could not be split. 
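+   * A usage sketch (mirroring computeBestSplits above): children of a node that was
+   * successfully split are returned for the next level of training, while a node that could
+   * not be split is marked as a leaf and yields no children.
+   * {{{
+   *   val nextLevelChildren = splitIfPossible(node, metadata, bestStats, bestSplit).toArray
+   * }}}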
+ * + * @param metadata learning and dataset metadata for DecisionTree + * @param stats Label impurity stats associated with the current node + */ + private[impl] def splitIfPossible( + node: LearningNode, + metadata: DecisionTreeMetadata, + stats: ImpurityStats, + split: Split): Iterator[LearningNode] = { + if (stats.valid) { + // Split node and return an iterator over its children; we filter out leaf nodes later + doSplit(node, split, stats) + Iterator(node.leftChild.get, node.rightChild.get) + } else { + node.stats = stats + node.isLeaf = true + Iterator() + } + } + + /** + * Splits the passed-in node. This method returns nothing, but modifies the passed-in node + * by updating its split and stats members. + * + * @param split Split to associate with the passed-in node + * @param stats Label impurity statistics to associate with the passed-in node + */ + private[impl] def doSplit( + node: LearningNode, + split: Split, + stats: ImpurityStats): Unit = { + val leftChildIsLeaf = stats.leftImpurity == 0 + node.leftChild = Some(LearningNode(id = LearningNode.leftChildIndex(node.id), + isLeaf = leftChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + val rightChildIsLeaf = stats.rightImpurity == 0 + node.rightChild = Some(LearningNode(id = LearningNode.rightChildIndex(node.id), + isLeaf = rightChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator) + )) + node.split = Some(split) + node.isLeaf = false + node.stats = stats + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala new file mode 100644 index 0000000000000..9ae7951d40f5a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.util.collection.BitSet + +/** + * Utility methods specific to local decision tree training. + */ +private[ml] object LocalDecisionTreeUtils extends Logging { + + /** + * Convert a dataset of binned feature values from row storage to column storage. + * Stores data as [[org.apache.spark.ml.linalg.DenseVector]]. + * + * + * @param rowStore An array of input data rows, each represented as an + * int array of binned feature values + * @return Transpose of rowStore as an array of columns consisting of binned feature values. + * + * TODO: Add implementation for sparse data. + * For sparse data, distribute more evenly based on number of non-zeros. + * (First collect stats to decide how to partition.) 
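+   *
+   * For example (hypothetical binned values):
+   * {{{
+   *   rowToColumnStoreDense(Array(Array(0, 1, 2), Array(3, 4, 5)))
+   *   // == Array(Array(0, 3), Array(1, 4), Array(2, 5))
+   * }}}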
+ */ + private[impl] def rowToColumnStoreDense(rowStore: Array[Array[Int]]): Array[Array[Int]] = { + // Compute the number of rows in the data + val numRows = { + val longNumRows: Long = rowStore.length + require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + + s" but can handle at most ${Int.MaxValue} rows") + longNumRows.toInt + } + + // Check that the input dataset isn't empty (0 rows) or featureless (rows with 0 features) + require(numRows > 0, "Local decision tree training requires numRows > 0.") + val numFeatures = rowStore(0).length + require(numFeatures > 0, "Local decision tree training requires numFeatures > 0.") + // Return the transpose of the rowStore matrix + rowStore.transpose + } + + /** + * Reorders the subset of array values at indices [from, to) + * according to the split information encoded in instanceBitVector (values for rows + * that split left appear before feature values for rows that split right). + * + * @param numLeftRows Number of rows on the left side of the split + * @param tempVals Destination buffer for reordered feature values + * @param instanceBitVector instanceBitVector(i) = true if the row corresponding to the + * (from + i)th array value splits right, false otherwise + */ + private[ml] def updateArrayForSplit( + values: Array[Int], + from: Int, + to: Int, + numLeftRows: Int, + tempVals: Array[Int], + instanceBitVector: BitSet): Unit = { + + // BEGIN SORTING + // We sort the [from, to) slice of col based on instance bit. + // All instances going "left" in the split (which are false) + // should be ordered before the instances going "right". The instanceBitVector + // gives us the split bit value for each instance based on the instance's index. + // We copy our feature values into @tempVals and @tempIndices either: + // 1) in the [from, numLeftRows) range if the bit is false, or + // 2) in the [numLeftRows, to) range if the bit is true. + var (leftInstanceIdx, rightInstanceIdx) = (0, numLeftRows) + var idx = from + while (idx < to) { + val bit = instanceBitVector.get(idx - from) + if (bit) { + tempVals(rightInstanceIdx) = values(idx) + rightInstanceIdx += 1 + } else { + tempVals(leftInstanceIdx) = values(idx) + leftInstanceIdx += 1 + } + idx += 1 + } + // END SORTING + // update the column values and indices + // with the corresponding indices + System.arraycopy(tempVals, 0, values, from, to - from) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index acfc6399c553b..f8c3dd7ff2e7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.tree.impl import java.io.IOException -import scala.collection.mutable +import scala.collection.{mutable, SeqView} import scala.util.Random import org.apache.spark.internal.Logging @@ -280,23 +280,14 @@ private[spark] object RandomForest extends Logging { featureIndexIdx } if (unorderedFeatures.contains(featureIndex)) { - // Unordered feature - val featureValue = treePoint.binnedFeatures(featureIndex) - val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) - // Update the left or right bin for each split. 
- val numSplits = agg.metadata.numSplits(featureIndex) - val featureSplits = splits(featureIndex) - var splitIndex = 0 - while (splitIndex < numSplits) { - if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { - agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight) - } - splitIndex += 1 - } + AggUpdateUtils.updateUnorderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndex = featureIndex, featureIndexIdx = featureIndexIdx, + featureSplits = splits(featureIndex), instanceWeight = instanceWeight) } else { - // Ordered feature - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.update(featureIndexIdx, binIndex, treePoint.label, instanceWeight) + AggUpdateUtils.updateOrderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndexIdx = featureIndexIdx, instanceWeight = instanceWeight) } featureIndexIdx += 1 } @@ -550,6 +541,7 @@ private[spark] object RandomForest extends Logging { } } + // Aggregate sufficient stats by node, then find best splits val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { case (nodeIndex, aggStats) => val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => @@ -558,12 +550,13 @@ private[spark] object RandomForest extends Logging { // find best split for each node val (split: Split, stats: ImpurityStats) = - binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + RandomForest.binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats)) }.collectAsMap() timer.stop("chooseSplits") + // Perform splits val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { Array.fill[mutable.Map[Int, NodeIndexUpdater]]( metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) @@ -627,221 +620,38 @@ private[spark] object RandomForest extends Logging { } /** - * Calculate the impurity statistics for a given (feature, split) based upon left/right - * aggregates. - * - * @param stats the recycle impurity statistics for this feature's all splits, - * only 'impurity' and 'impurityCalculator' are valid between each iteration - * @param leftImpurityCalculator left node aggregates for this (feature, split) - * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @param metadata learning and dataset metadata for DecisionTree - * @return Impurity statistics for this (feature, split) + * Return a list of pairs (featureIndexIdx, featureIndex) where featureIndex is the global + * (across all trees) index of a feature and featureIndexIdx is the index of a feature within the + * list of features for a given node. 
Filters out features known to be constant + * (features with 0 splits) */ - private def calculateImpurityStats( - stats: ImpurityStats, - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): ImpurityStats = { - - val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { - leftImpurityCalculator.copy.add(rightImpurityCalculator) - } else { - stats.impurityCalculator - } - - val impurity: Double = if (stats == null) { - parentImpurityCalculator.calculate() - } else { - stats.impurity - } - - val leftCount = leftImpurityCalculator.count - val rightCount = rightImpurityCalculator.count - - val totalCount = leftCount + rightCount - - // If left child or right child doesn't satisfy minimum instances per node, - // then this split is invalid, return invalid information gain stats. - if ((leftCount < metadata.minInstancesPerNode) || - (rightCount < metadata.minInstancesPerNode)) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) - } - - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 - val rightImpurity = rightImpurityCalculator.calculate() - - val leftWeight = leftCount / totalCount.toDouble - val rightWeight = rightCount / totalCount.toDouble - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - // if information gain doesn't satisfy minimum information gain, - // then this split is invalid, return invalid information gain stats. - if (gain < metadata.minInfoGain) { - return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + private[impl] def getFeaturesWithSplits( + metadata: DecisionTreeMetadata, + featuresForNode: Option[Array[Int]]): SeqView[(Int, Int), Seq[_]] = { + Range(0, metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + metadata.numSplits(featureIndex) != 0 } - - new ImpurityStats(gain, impurity, parentImpurityCalculator, - leftImpurityCalculator, rightImpurityCalculator) } - /** - * Find the best split for a node. - * - * @param binAggregates Bin statistics. - * @return tuple for best split: (Split, information gain, prediction at node) - */ - private[tree] def binsToBestSplit( - binAggregates: DTStatsAggregator, - splits: Array[Array[Split]], + private[impl] def getBestSplitByGain( + parentImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata, featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, ImpurityStats) = { - - // Calculate InformationGain and ImpurityStats if current node is top node - val level = LearningNode.indexToLevel(node.id) - var gainAndImpurityStats: ImpurityStats = if (level == 0) { - null - } else { - node.stats - } - - val validFeatureSplits = - Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => - featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) - .getOrElse((featureIndexIdx, featureIndexIdx)) - }.withFilter { case (_, featureIndex) => - binAggregates.metadata.numSplits(featureIndex) != 0 - } - - // For each (feature, split), calculate the gain, and select the best (feature, split). 
- val splitsAndImpurityInfo = - validFeatureSplits.map { case (featureIndexIdx, featureIndex) => - val numSplits = binAggregates.metadata.numSplits(featureIndex) - if (binAggregates.metadata.isContinuous(featureIndex)) { - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - var splitIndex = 0 - while (splitIndex < numSplits) { - binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) - splitIndex += 1 - } - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { case splitIdx => - val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIdx, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else if (binAggregates.metadata.isUnordered(featureIndex)) { - // Unordered categorical feature - val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) - val rightChildStats = binAggregates.getParentImpurityCalculator() - .subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) - } else { - // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) - val numCategories = binAggregates.metadata.numBins(featureIndex) - - /* Each bin is one category (feature value). - * The bins are ordered based on centroidForCategories, and this ordering determines which - * splits are considered. (With K categories, we consider K - 1 possible splits.) - * - * centroidForCategories is a list: (category, centroid) - */ - val centroidForCategories = Range(0, numCategories).map { case featureValue => - val categoryStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val centroid = if (categoryStats.count != 0) { - if (binAggregates.metadata.isMulticlass) { - // multiclass classification - // For categorical variables in multiclass classification, - // the bins are ordered by the impurity of their corresponding labels. - categoryStats.calculate() - } else if (binAggregates.metadata.isClassification) { - // binary classification - // For categorical variables in binary classification, - // the bins are ordered by the count of class 1. - categoryStats.stats(1) - } else { - // regression - // For categorical variables in regression and binary classification, - // the bins are ordered by the prediction. 
- categoryStats.predict - } - } else { - Double.MaxValue - } - (featureValue, centroid) - } - - logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - - // bins sorted by centroids - val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - - logDebug("Sorted centroids for categorical variable = " + - categoriesSortedByCentroid.mkString(",")) - - // Cumulative sum (scanLeft) of bin statistics. - // Afterwards, binAggregates for a bin is the sum of aggregates for - // that bin + all preceding bins. - var splitIndex = 0 - while (splitIndex < numSplits) { - val currentCategory = categoriesSortedByCentroid(splitIndex)._1 - val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 - binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) - splitIndex += 1 - } - // lastCategory = index of bin with total aggregates for this (node, feature) - val lastCategory = categoriesSortedByCentroid.last._1 - // Find best split. - val (bestFeatureSplitIndex, bestFeatureGainStats) = - Range(0, numSplits).map { splitIndex => - val featureValue = categoriesSortedByCentroid(splitIndex)._1 - val leftChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) - val rightChildStats = - binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) - rightChildStats.subtract(leftChildStats) - gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, - leftChildStats, rightChildStats, binAggregates.metadata) - (splitIndex, gainAndImpurityStats) - }.maxBy(_._2.gain) - val categoriesForSplit = - categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) - val bestFeatureSplit = - new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) - (bestFeatureSplit, bestFeatureGainStats) - } - } - + splitsAndImpurityInfo: Seq[(Split, ImpurityStats)]): (Split, ImpurityStats) = { val (bestSplit, bestSplitStats) = if (splitsAndImpurityInfo.isEmpty) { // If no valid splits for features, then this split is invalid, // return invalid information gain stats. Take any split and continue. // Splits is empty, so arbitrarily choose to split on any threshold val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0) - val parentImpurityCalculator = binAggregates.getParentImpurityCalculator() - if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) { + if (metadata.isContinuous(dummyFeatureIndex)) { (new ContinuousSplit(dummyFeatureIndex, 0), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } else { - val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex) + val numCategories = metadata.featureArity(dummyFeatureIndex) (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) } @@ -851,6 +661,41 @@ private[spark] object RandomForest extends Logging { (bestSplit, bestSplitStats) } + /** + * Find the best split for a node. + * + * @param binAggregates Bin statistics. + * @return tuple for best split: (Split, information gain, prediction at node) + */ + private[tree] def binsToBestSplit( + binAggregates: DTStatsAggregator, + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]], + node: LearningNode): (Split, ImpurityStats) = { + val validFeatureSplits = getFeaturesWithSplits(binAggregates.metadata, featuresForNode) + // For each (feature, split), calculate the gain, and select the best (feature, split). 
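+    // node.stats is null for the root node, whose impurity stats have not been computed yet;
+    // in that case the parent stats are derived from the bin aggregates inside chooseSplit.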
+ val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator) + val splitsAndImpurityInfo = + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => + SplitUtils.chooseSplit(binAggregates, featureIndex, featureIndexIdx, splits(featureIndex), + parentImpurityCalc) + } + getBestSplitByGain(binAggregates.getParentImpurityCalculator(), binAggregates.metadata, + featuresForNode, splitsAndImpurityInfo) + } + + private[impl] def findUnorderedSplits( + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Split] = { + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + val featureArity = metadata.featureArity(featureIndex) + Array.tabulate[Split](metadata.numSplits(featureIndex)) { splitIndex => + val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + new CategoricalSplit(featureIndex, categories.toArray, featureArity) + } + } + /** * Returns splits for decision tree calculation. * Continuous and categorical features are handled differently. @@ -936,13 +781,7 @@ private[spark] object RandomForest extends Logging { split case i if metadata.isCategorical(i) && metadata.isUnordered(i) => - // Unordered features - // 2^(maxFeatureValue - 1) - 1 combinations - val featureArity = metadata.featureArity(i) - Array.tabulate[Split](metadata.numSplits(i)) { splitIndex => - val categories = extractMultiClassCategories(splitIndex + 1, featureArity) - new CategoricalSplit(i, categories.toArray, featureArity) - } + findUnorderedSplits(metadata, i) case i if metadata.isCategorical(i) => // Ordered features @@ -1147,4 +986,5 @@ private[spark] object RandomForest extends Logging { 3 * totalBins } } + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala new file mode 100644 index 0000000000000..206405a69305c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.{CategoricalSplit, Split} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Utility methods for choosing splits during local & distributed tree training. */ +private[impl] object SplitUtils extends Logging { + + /** Sorts ordered feature categories by label centroid, returning an ordered list of categories */ + private def sortByCentroid( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int): List[Int] = { + /* Each bin is one category (feature value). 
+ * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val numCategories = binAggregates.metadata.numBins(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + + val centroidForCategories = Range(0, numCategories).map { featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = ImpurityUtils.getCentroid(binAggregates.metadata, categoryStats) + (featureValue, centroid) + } + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2).map(_._1) + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + categoriesSortedByCentroid + } + + /** + * Find the best split for an unordered categorical feature at a single node. + * + * Algorithm: + * - Considers all possible subsets (exponentially many) + * + * @param featureIndex Global index of feature being split. + * @param featureIndexIdx Index of feature being split within subset of features for current node. + * @param featureSplits Array of splits for the current feature + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + * @return (best split, statistics for split) If no valid split was found, the returned + * ImpurityStats instance will be invalid (have member valid = false). + */ + private[impl] def chooseUnorderedCategoricalSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // Unordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numSplits = binAggregates.metadata.numSplits(featureIndex) + val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator()) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + .subtract(leftChildStats) + val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats) + + } + + /** + * Choose splitting rule: feature value <= threshold + * + * @return (best split, statistics for split) If the best split actually puts all instances + * in one leaf node, then it will be set to None. 
If no valid split was found, the + * returned ImpurityStats instance will be invalid (have member valid = false) + */ + private[impl] def chooseContinuousSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // For a continuous feature, bins are already sorted for splitting + // Number of "categories" = number of bins + val sortedCategories = Range(0, binAggregates.metadata.numBins(featureIndex)).toList + // Get & return best split info + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, sortedCategories, parentCalculator) + (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats) + } + + /** + * Computes the index of the best split for an ordered feature. + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private def orderedSplitHelper( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + categoriesSortedByCentroid: List[Int], + parentCalculator: Option[ImpurityCalculator]): (Int, ImpurityStats) = { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val numSplits = binAggregates.metadata.numSplits(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex) + val nextCategory = categoriesSortedByCentroid(splitIndex + 1) + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last + + // Find best split. + val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator()) + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex) + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + } + + /** + * Choose the best split for an ordered categorical feature. 
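+   * Categories are first ordered by their label centroids and then treated like an ordered
+   * feature, so with K categories only K - 1 candidate splits are evaluated.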
+ * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private[impl] def chooseOrderedCategoricalSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // Sort feature categories by label centroid + val categoriesSortedByCentroid = sortByCentroid(binAggregates, featureIndex, featureIndexIdx) + // Get index, stats of best split + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, categoriesSortedByCentroid, parentCalculator) + // Create result (CategoricalSplit instance) + val categoriesForSplit = + categoriesSortedByCentroid.map(_.toDouble).slice(0, bestFeatureSplitIndex + 1) + val numCategories = binAggregates.metadata.featureArity(featureIndex) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) + } + + /** + * Choose the best split for a feature at a node. + * + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + * @return (best split, statistics for split) If no valid split was found, the returned + * ImpurityStats will have member stats.valid = false. + */ + private[impl] def chooseSplit( + statsAggregator: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + val metadata = statsAggregator.metadata + if (metadata.isCategorical(featureIndex)) { + if (metadata.isUnordered(featureIndex)) { + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, featureSplits, parentCalculator) + } else { + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, parentCalculator) + } + } else { + SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, featureIndexIdx, + featureSplits, parentCalculator) + } + + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TrainingInfo.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TrainingInfo.scala new file mode 100644 index 0000000000000..490afb2f53ad3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TrainingInfo.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.ml.tree.{LearningNode, Split} +import org.apache.spark.util.collection.BitSet + +/** + * Maintains intermediate state of data (columns) and tree during local tree training. 
+ * Primary local tree training data structure; contains all information required to describe + * the state of the algorithm at any point during learning.?? + * + * Nodes are indexed left-to-right along the periphery of the tree, with 0-based indices. + * The "periphery" is the set of leaf nodes (active and inactive). + * + * @param columns Array of columns. + * Each column is sorted first by nodes (left-to-right along the tree periphery); + * all columns share this first level of sorting. + * @param nodeOffsets Offsets into the columns indicating the first level of sorting (by node). + * The rows corresponding to the node activeNodes(i) are in the range + * [nodeOffsets(i)(0), nodeOffsets(i)(1)) . + * @param currentLevelActiveNodes Nodes which are active (could still be split). + * Inactive nodes are known to be leaves in the final tree. + */ +private[impl] case class TrainingInfo( + columns: Array[FeatureColumn], + nodeOffsets: Array[(Int, Int)], + currentLevelActiveNodes: Array[LearningNode], + rowIndices: Option[Array[Int]] = None) extends Serializable { + + // pre-allocated temporary buffers that we use to sort + // instances in left and right children during update + val tempVals: Array[Int] = new Array[Int](columns.head.values.length) + + // Array of row indices for feature values, shared across all columns. + // For each column (col) in [[columns]], col(j) is the feature value corresponding to the row + // with index indices(j). + val indices: Array[Int] = rowIndices.getOrElse(columns.head.values.indices.toArray) + + /** For debugging */ + override def toString: String = { + "TrainingInfo(" + + " columns: {\n" + + columns.mkString(",\n") + + " },\n" + + s" nodeOffsets: ${nodeOffsets.mkString(", ")},\n" + + s" activeNodes: ${currentLevelActiveNodes.iterator.mkString(", ")},\n" + + ")\n" + } + + /** + * Update columns and nodeOffsets for the next level of the tree. + * + * Update columns: + * For each (previously) active node, + * Compute bitset indicating whether each training instance under the node splits left/right + * For each column, + * Sort corresponding range of instances based on bitset. + * Update nodeOffsets, activeNodes: + * Split offsets for nodes which split (which can be identified using the bitset). + * + * @return Updated partition info + */ + def update(splits: Array[Array[Split]], newActiveNodes: Array[LearningNode]): TrainingInfo = { + // Create buffers for storing our new arrays of node offsets & impurities + val newNodeOffsets = new ArrayBuffer[(Int, Int)]() + // Update (per-node) sorting of each column to account for creation of new nodes + var nodeIdx = 0 + while (nodeIdx < currentLevelActiveNodes.length) { + val node = currentLevelActiveNodes(nodeIdx) + // Get new active node offsets from active nodes that were split + if (!node.isLeaf) { + // Get split and FeatureVector corresponding to feature for split + val split = node.split.get + val col = columns(split.featureIndex) + val (from, to) = nodeOffsets(nodeIdx) + // Compute bitset indicating whether each training example splits left/right + val bitset = TrainingInfo.bitSetFromSplit(col, from, to, split, splits(split.featureIndex)) + // Update each column according to the bitset + val numRows = to - from + // Allocate shared temp buffers (shared across all columns) for reordering + // feature values/indices for current node. 
+ val tempVals = new Array[Int](numRows) + val numLeftRows = numRows - bitset.cardinality() + // Reorder values for each column + columns.foreach { col => + LocalDecisionTreeUtils.updateArrayForSplit(col.values, from, to, numLeftRows, tempVals, + bitset) + } + // Reorder indices (shared across all columns) + LocalDecisionTreeUtils.updateArrayForSplit(indices, from, to, numLeftRows, tempVals, bitset) + // Add new node offsets to array + val leftIndices = (from, from + numLeftRows) + val rightIndices = (from + numLeftRows, to) + newNodeOffsets ++= Array(leftIndices, rightIndices) + } + nodeIdx += 1 + } + TrainingInfo(columns, newNodeOffsets.toArray, newActiveNodes, Some(indices)) + } + +} + +/** Training-info specific utility methods. */ +private[impl] object TrainingInfo { + /** + * For a given feature, for a given node, apply a split and return a bitset indicating the + * outcome of the split for each instance at that node. + * + * @param col Column for feature + * @param from Start offset in col for the node + * @param to End offset in col for the node + * @param split Split to apply to instances at this node. + * @return Bitset indicating splits for instances at this node. + * These bits are sorted by the row indices. + * bitset(i) = true if ith example for current node splits right, false otherwise. + */ + private[impl] def bitSetFromSplit( + col: FeatureColumn, + from: Int, + to: Int, + split: Split, + featureSplits: Array[Split]): BitSet = { + val bitset = new BitSet(to - from) + from.until(to).foreach { i => + if (!split.shouldGoLeft(col.values(i), featureSplits)) { + bitset.set(i - from) + } + } + bitset + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index f3dbfd96e1815..029a709f553d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -75,8 +75,9 @@ class InformationGainStats( * @param impurityCalculator impurity statistics for current node * @param leftImpurityCalculator impurity statistics for left child node * @param rightImpurityCalculator impurity statistics for right child node - * @param valid whether the current split satisfies minimum info gain or - * minimum number of instances per node + * @param valid whether the current split should be performed; true if split + * satisfies minimum info gain, minimum number of instances per node, and + * has positive info gain. */ private[spark] class ImpurityStats( val gain: Double, @@ -112,7 +113,7 @@ private[spark] object ImpurityStats { * minimum number of instances per node. */ def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { - new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + new ImpurityStats(Double.MinValue, impurity = -1, impurityCalculator, null, null, false) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala new file mode 100644 index 0000000000000..a9bfe2164b15f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.Predictor +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree.{DecisionTreeParams, TreeRegressorParams} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset + +/** + * Test-only class for fitting a decision tree regressor on a dataset small enough to fit on a + * single machine. + */ +private[impl] final class LocalDecisionTreeRegressor(override val uid: String) + extends Predictor[Vector, LocalDecisionTreeRegressor, DecisionTreeRegressionModel] + with DecisionTreeParams with TreeRegressorParams { + + def this() = this(Identifiable.randomUID("local_dtr")) + + // Override parameter setters from parent trait for Java API compatibility. + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + override def copy(extra: ParamMap): LocalDecisionTreeRegressor = defaultCopy(extra) + + override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = getOldStrategy(categoricalFeatures) + val model = LocalTreeTests.train(oldDataset, strategy, parentUID = Some(uid), + seed = getSeed) + model.asInstanceOf[DecisionTreeRegressionModel] + } + + /** Create a Strategy instance to use with the old API. 
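+   * Note: numClasses is set to 0 and the algo to Regression, since this test-only estimator
+   * only supports regression.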
*/ + private[impl] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, + subsamplingRate = 1.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeDataSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeDataSuite.scala new file mode 100644 index 0000000000000..c90b21842cb83 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeDataSuite.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, LearningNode, Split} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.collection.BitSet + +/** Suite exercising data structures (FeatureVector, TrainingInfo) for local tree training. 
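+ * Covers column reordering on splits, deep copies of columns, and bitset computation
+ * for splits.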
*/ +class LocalTreeDataSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("FeatureVector: updating columns for split") { + val vecLength = 100 + // Create a column of vecLength values + val values = 0.until(vecLength).toArray + val col = FeatureColumn(-1, values) + // Pick a random subset of indices to split left + val rng = new Random(seed = 42) + val leftProb = 0.5 + val (leftIdxs, rightIdxs) = values.indices.partition(_ => rng.nextDouble() < leftProb) + // Determine our expected result after updating for split + val expected = leftIdxs.map(values(_)) ++ rightIdxs.map(values(_)) + // Create a bitset indicating whether each of our values splits left or right + val instanceBitVector = new BitSet(values.length) + rightIdxs.foreach(instanceBitVector.set) + // Update column, compare new values to expected result + val tempVals = new Array[Int](vecLength) + val tempIndices = new Array[Int](vecLength) + LocalDecisionTreeUtils.updateArrayForSplit(col.values, from = 0, to = vecLength, + leftIdxs.length, tempVals, instanceBitVector) + assert(col.values.sameElements(expected)) + } + + /* Check that FeatureVector methods produce expected results */ + test("FeatureVector: constructor and deepCopy") { + // Create a feature vector v, modify a deep copy of v, and check that + // v itself was not modified + val v = new FeatureColumn(1, Array(1, 2, 3)) + val vCopy = v.deepCopy() + vCopy.values(0) = 1000 + assert(v.values(0) !== vCopy.values(0)) + } + + // Get common TrainingInfo for tests + // Data: + // Feature 0 (continuous): [3, 2, 0, 1] + // Feature 1 (categorical):[0, 0, 2, 1] + private def getTrainingInfo(): TrainingInfo = { + val numRows = 4 + // col1 is continuous features + val col1 = FeatureColumn(featureIndex = 0, Array(3, 2, 0, 1)) + // col2 is categorical features + val catFeatureIdx = 1 + val col2 = FeatureColumn(featureIndex = catFeatureIdx, values = Array(0, 0, 2, 1)) + + val nodeOffsets = Array((0, numRows)) + val activeNodes = Array(LearningNode.emptyNode(nodeIndex = -1)) + TrainingInfo(Array(col1, col2), nodeOffsets, activeNodes) + } + + // Check that TrainingInfo correctly updates node offsets, sorts column values during update() + test("TrainingInfo.update(): correctness when splitting on continuous features") { + // Get TrainingInfo + // Feature 0 (continuous): [3, 2, 0, 1] + // Feature 1 (categorical):[0, 0, 2, 1] + val info = getTrainingInfo() + val activeNodes = info.currentLevelActiveNodes + val contFeatureIdx = 0 + + // For continuous feature, active node has a split with threshold 1 + val contNode = activeNodes(contFeatureIdx) + contNode.split = Some(new ContinuousSplit(contFeatureIdx, threshold = 1)) + + // Update TrainingInfo for continuous split + val contValues = info.columns(contFeatureIdx).values + val splits = Array(LocalTreeTests.getContinuousSplits(contValues, contFeatureIdx)) + val newInfo = info.update(splits, newActiveNodes = Array(contNode)) + + assert(newInfo.columns.length === 2) + // Continuous split should send feature values [0, 1] to the left, [3, 2] to the right + // ==> row indices (2, 3) should split left, row indices (0, 1) should split right + val expectedContCol = new FeatureColumn(0, values = Array(0, 1, 3, 2)) + val expectedCatCol = new FeatureColumn(1, values = Array(2, 1, 0, 0)) + val expectedIndices = Array(2, 3, 0, 1) + assert(newInfo.columns(0) === expectedContCol) + assert(newInfo.columns(1) === expectedCatCol) + assert(newInfo.indices === expectedIndices) + // Check that node offsets were updated 
properly + assert(newInfo.nodeOffsets === Array((0, 2), (2, 4))) + } + + test("TrainingInfo.update(): correctness when splitting on categorical features") { + // Get TrainingInfo + // Feature 0 (continuous): [3, 2, 0, 1] + // Feature 1 (categorical):[0, 0, 2, 1] + val info = getTrainingInfo() + val activeNodes = info.currentLevelActiveNodes + val catFeatureIdx = 1 + + // For categorical feature, active node puts category 2 on left side of split + val catNode = activeNodes(0) + val catSplit = new CategoricalSplit(catFeatureIdx, _leftCategories = Array(2), + numCategories = 3) + catNode.split = Some(catSplit) + + // Update TrainingInfo for categorical split + val splits: Array[Array[Split]] = Array(Array.empty, Array(catSplit)) + val newInfo = info.update(splits, newActiveNodes = Array(catNode)) + + assert(newInfo.columns.length === 2) + // Categorical split should send feature values [2] to the left, [0, 1] to the right + // ==> row 2 should split left, rows [0, 1, 3] should split right + val expectedContCol = new FeatureColumn(0, values = Array(0, 3, 2, 1)) + val expectedCatCol = new FeatureColumn(1, values = Array(2, 0, 0, 1)) + val expectedIndices = Array(2, 0, 1, 3) + assert(newInfo.columns(0) === expectedContCol) + assert(newInfo.columns(1) === expectedCatCol) + assert(newInfo.indices === expectedIndices) + // Check that node offsets were updated properly + assert(newInfo.nodeOffsets === Array((0, 1), (1, 4))) + } + + private def getSetBits(bitset: BitSet): Set[Int] = { + Range(0, bitset.capacity).filter(bitset.get).toSet + } + + test("TrainingInfo.bitSetFromSplit correctness: splitting a single node") { + val featureIndex = 0 + val thresholds = Array(1, 2, 4, 6, 7) + val values = thresholds.indices.toArray + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val col = FeatureColumn(0, values) + val fromOffset = 0 + val toOffset = col.values.length + val numRows = toOffset + // Create split; first three rows (with feature values [1, 2, 4]) should split left, as they + // have feature values <= 5. Last two rows (feature values [6, 7]) should split right. + val split = new ContinuousSplit(0, threshold = 5) + val bitset = TrainingInfo.bitSetFromSplit(col, fromOffset, toOffset, split, splits) + // Check that the last two rows (row indices [3, 4] within the set of rows being split) + // fall on the right side of the split. + assert(getSetBits(bitset) === Set(3, 4)) + } + + test("TrainingInfo.bitSetFromSplit correctness: splitting 2 nodes") { + // Assume there was already 1 split, which split rows (represented by row index) as: + // (0, 2, 4) | (1, 3) + val thresholds = Array(1, 2, 4, 6, 7) + val values = thresholds.indices.toArray + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex = 0) + val col = new FeatureColumn(0, values) + + /** + * Computes a bitset for splitting rows in with indices in [fromOffset, toOffset) using a + * continuous split with the specified threshold. Then, checks that right side of the split + * contains the row indices in expectedRight. 
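+     * For example, with this column's underlying thresholds [1, 2, 4, 6, 7], calling
+     * checkSplit(fromOffset = 0, toOffset = 3, threshold = 1.5, expectedRight = Set(1, 2))
+     * asserts that only the first row (threshold 1) goes left while rows 1 and 2 go right.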
+ */ + def checkSplit( + fromOffset: Int, + toOffset: Int, + threshold: Double, + expectedRight: Set[Int]): Unit = { + val split = new ContinuousSplit(0, threshold) + val numRows = col.values.length + val bitset = TrainingInfo.bitSetFromSplit(col, fromOffset, toOffset, split, splits) + assert(getSetBits(bitset) === expectedRight) + } + + // Split rows corresponding to left child node (rows [0, 2, 4]) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 0.5, expectedRight = Set(0, 1, 2)) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 1.5, expectedRight = Set(1, 2)) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 2, expectedRight = Set(2)) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 5, expectedRight = Set()) + // Split rows corresponding to right child node (rows [1, 3]) + checkSplit(fromOffset = 3, toOffset = 5, threshold = 1, expectedRight = Set(0, 1)) + checkSplit(fromOffset = 3, toOffset = 5, threshold = 6.5, expectedRight = Set(1)) + checkSplit(fromOffset = 3, toOffset = 5, threshold = 8, expectedRight = Set()) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala new file mode 100644 index 0000000000000..9b34cc105d8df --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.mllib.tree.DecisionTreeSuite +import org.apache.spark.mllib.util.{LogisticRegressionDataGenerator, MLlibTestSparkContext} +import org.apache.spark.sql.DataFrame + +/** Tests checking equivalence of trees produced by local and distributed tree training. */ +class LocalTreeIntegrationSuite extends SparkFunSuite with MLlibTestSparkContext { + + val medDepthTreeSettings = TreeTests.allParamSettings ++ Map[String, Any]("maxDepth" -> 4) + + /** + * For each (paramName, paramVal) pair in the passed-in map, set the corresponding + * parameter of the passed-in estimator & return the estimator. + */ + private def setParams[E <: Estimator[_]](estimator: E, params: Map[String, Any]): E = { + params.foreach { case (p, v) => + estimator.set(estimator.getParam(p), v) + } + estimator + } + + /** + * Verifies that local tree training & distributed training produce the same tree + * when fit on the same dataset with the same set of params. 
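+   * For example, testEquivalence(df, TreeTests.allParamSettings) fits both a
+   * DecisionTreeRegressor and a LocalDecisionTreeRegressor on df with identical params and
+   * compares the fitted trees via TreeTests.checkEqual.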
+   */
+  private def testEquivalence(train: DataFrame, testParams: Map[String, Any]): Unit = {
+    val distribTree = setParams(new DecisionTreeRegressor(), testParams)
+    val localTree = setParams(new LocalDecisionTreeRegressor(), testParams)
+    val model = distribTree.fit(train)
+    val localModel = localTree.fit(train)
+    TreeTests.checkEqual(localModel, model)
+  }
+
+  test("Local & distributed training produce the same tree on a toy dataset") {
+    val data = sc.parallelize(Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x))))
+    val df = spark.createDataFrame(data)
+    testEquivalence(df, TreeTests.allParamSettings)
+  }
+
+  test("Local & distributed training produce the same tree on a larger toy dataset") {
+    val data = sc.parallelize(Range(0, 64).map(x => LabeledPoint(x, Vectors.dense(x))))
+    val df = spark.createDataFrame(data)
+    testEquivalence(df, medDepthTreeSettings)
+  }
+
+  test("Local & distributed training produce same tree on a dataset of categorical features") {
+    val data = sc.parallelize(DecisionTreeSuite.generateCategoricalDataPoints().map(_.asML))
+    // Create a map of categorical feature index to arity; each feature has arity nclasses
+    val featuresMap: Map[Int, Int] = Map(0 -> 3, 1 -> 3)
+    // Convert the data RDD to a DataFrame with metadata indicating the arity of each of its
+    // categorical features
+    val df = TreeTests.setMetadata(data, featuresMap, numClasses = 2)
+    testEquivalence(df, TreeTests.allParamSettings)
+  }
+
+  test("Local & distributed training produce the same tree on a dataset of continuous features") {
+    val sqlContext = spark.sqlContext
+    import sqlContext.implicits._
+    // Use the standard test params with maxDepth = 4 (medDepthTreeSettings)
+    val params = medDepthTreeSettings
+    val data = LogisticRegressionDataGenerator.generateLogisticRDD(spark.sparkContext,
+      nexamples = 1000, nfeatures = 5, eps = 2.0, nparts = 1, probOne = 0.2)
+      .map(_.asML).toDF().cache()
+    testEquivalence(data, params)
+  }
+
+  test("Local & distributed training produce the same tree on a dataset of constant features") {
+    // Generate constant, continuous data
+    val data = sc.parallelize(Range(0, 8).map(_ => LabeledPoint(1, Vectors.dense(1))))
+    val df = spark.createDataFrame(data)
+    testEquivalence(df, TreeTests.allParamSettings)
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeTests.scala
new file mode 100644
index 0000000000000..6bca306410a70
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeTests.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.rdd.RDD + + +/** Object providing test-only methods for local decision tree training. */ +private[impl] object LocalTreeTests extends Logging { + + /** + * Given the root node of a decision tree, returns a corresponding DecisionTreeModel + * @param algo Enum describing the algorithm used to fit the tree + * @param numClasses Number of label classes (for classification trees) + * @param parentUID UID of parent estimator + */ + private[impl] def finalizeTree( + rootNode: Node, + algo: OldAlgo.Algo, + numClasses: Int, + numFeatures: Int, + parentUID: Option[String]): DecisionTreeModel = { + parentUID match { + case Some(uid) => + if (algo == OldAlgo.Classification) { + new DecisionTreeClassificationModel(uid, rootNode, numFeatures = numFeatures, + numClasses = numClasses) + } else { + new DecisionTreeRegressionModel(uid, rootNode, numFeatures = numFeatures) + } + case None => + if (algo == OldAlgo.Classification) { + new DecisionTreeClassificationModel(rootNode, numFeatures = numFeatures, + numClasses = numClasses) + } else { + new DecisionTreeRegressionModel(rootNode, numFeatures = numFeatures) + } + } + } + + /** + * Method to locally train a decision tree model over an RDD. Assumes the RDD is small enough + * to be collected at a single worker and used to fit a decision tree locally. + * Only used for testing. + */ + private[impl] def train( + input: RDD[LabeledPoint], + strategy: OldStrategy, + seed: Long, + parentUID: Option[String] = None): DecisionTreeModel = { + + // Validate input data + require(input.count() > 0, "Local decision tree training requires > 0 training examples.") + val numFeatures = input.first().features.size + require(numFeatures > 0, "Local decision tree training requires > 0 features.") + + // Construct metadata, find splits + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = RandomForest.findSplits(input, metadata, seed) + + // Bin feature values (convert to TreePoint representation). + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata).collect() + val instanceWeights = Array.fill[Double](treeInput.length)(1.0) + + // Create tree root node + val initialRoot = LearningNode.emptyNode(nodeIndex = 1) + // TODO: Create rng for feature subsampling (using seed), pass to fitNode + // Fit tree + val rootNode = LocalDecisionTree.fitNode(treeInput, instanceWeights, + initialRoot, metadata, splits) + finalizeTree(rootNode, strategy.algo, strategy.numClasses, numFeatures, parentUID) + } + + /** + * Returns an array of continuous splits for the feature with index featureIndex and the passed-in + * set of values. Creates one continuous split per value in values. 
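+   * For example, getContinuousSplits(Array(4, 1, 2), featureIndex = 0) returns splits with
+   * thresholds 1.0, 2.0, and 4.0 (the values are sorted before the splits are created).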
+ */ + private[impl] def getContinuousSplits( + values: Array[Int], + featureIndex: Int): Array[Split] = { + val splits = values.sorted.map { + new ContinuousSplit(featureIndex, _).asInstanceOf[Split] + } + splits + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUnitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUnitSuite.scala new file mode 100644 index 0000000000000..61db1e619c83b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUnitSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** Unit tests for helper classes/methods specific to local tree training */ +class LocalTreeUnitSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("Fit a single decision tree regressor on constant features") { + // Generate constant, continuous data + val data = sc.parallelize(Range(0, 8).map(_ => LabeledPoint(1, Vectors.dense(1)))) + val df = spark.sqlContext.createDataFrame(data) + // Initialize estimator + val dt = new LocalDecisionTreeRegressor() + .setFeaturesCol("features") // indexedFeatures + .setLabelCol("label") + .setMaxDepth(3) + // Fit model + val model = dt.fit(df) + assert(model.rootNode.isInstanceOf[LeafNode]) + val root = model.rootNode.asInstanceOf[LeafNode] + assert(root.prediction == 1) + } + + test("Fit a single decision tree regressor on some continuous features") { + // Generate continuous data + val data = sc.parallelize(Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + // Initialize estimator + val dt = new LocalDecisionTreeRegressor() + .setFeaturesCol("features") // indexedFeatures + .setLabelCol("label") + .setMaxDepth(3) + // Fit model + val model = dt.fit(df) + + // Check that model is of depth 3 (the specified max depth) and that leaf/internal nodes have + // the correct class. 
+    // Validate root
+    assert(model.rootNode.isInstanceOf[InternalNode])
+    // Validate first level of tree (nodes with depth = 1)
+    val root = model.rootNode.asInstanceOf[InternalNode]
+    assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[InternalNode])
+    // Validate second and third levels of tree (nodes with depth = 2 or 3)
+    val left = root.leftChild.asInstanceOf[InternalNode]
+    val right = root.rightChild.asInstanceOf[InternalNode]
+    val grandkids = Array(left.leftChild, left.rightChild, right.leftChild, right.rightChild)
+    grandkids.foreach { grandkid =>
+      assert(grandkid.isInstanceOf[InternalNode])
+      val grandkidNode = grandkid.asInstanceOf[InternalNode]
+      assert(grandkidNode.leftChild.isInstanceOf[LeafNode])
+      assert(grandkidNode.rightChild.isInstanceOf[LeafNode])
+    }
+  }
+
+  test("Fit deep local trees") {
+
+    /**
+     * Deep tree test. Tries to fit a tree on synthetic data designed to force the tree
+     * to split to the specified depth.
+     */
+    def deepTreeTest(depth: Int): Unit = {
+      val deepTreeData = TreeTests.deepTreeData(sc, depth)
+      val df = spark.createDataFrame(deepTreeData)
+      // Construct the estimator (a single local decision tree regressor)
+      val localTree = new LocalDecisionTreeRegressor()
+        .setFeaturesCol("features") // indexedFeatures
+        .setLabelCol("label")
+        .setMaxDepth(depth)
+        .setMinInfoGain(0.0)
+
+      // Fit model and check that the fitted tree has the expected depth
+      val localModel = localTree.fit(df)
+      assert(localModel.rootNode.subtreeDepth == depth)
+    }
+
+    // Test small depth tree
+    deepTreeTest(10)
+    // Test medium depth tree
+    deepTreeTest(40)
+    // Test high depth tree
+    deepTreeTest(200)
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUtilsSuite.scala
new file mode 100644
index 0000000000000..e2f52c445323f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUtilsSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.ContinuousSplit +import org.apache.spark.util.collection.BitSet + +/** Unit tests for helper classes/methods specific to local tree training */ +class LocalTreeUtilsSuite extends SparkFunSuite { + + test("rowToColumnStoreDense: transforms row-major data into a column-major representation") { + // Attempt to transform an empty training dataset + intercept[IllegalArgumentException] { + LocalDecisionTreeUtils.rowToColumnStoreDense(Array.empty) + } + + // Transform a training dataset consisting of a single row + { + val rowLength = 10 + val data = Array(0.until(rowLength).toArray) + val transposed = LocalDecisionTreeUtils.rowToColumnStoreDense(data) + assert(transposed.length == rowLength, + s"Column-major representation of $rowLength-element row " + + s"contained ${transposed.length} elements") + transposed.foreach { col => + assert(col.length == 1, s"Column-major representation of a single row " + + s"contained column of length ${col.length}, expected length: 1") + } + } + + // Transform a dataset consisting of a single column + { + val colSize = 10 + val data = Array.tabulate[Array[Int]](colSize)(Array(_)) + val transposed = LocalDecisionTreeUtils.rowToColumnStoreDense(data) + assert(transposed.length > 0, s"Column-major representation of $colSize-element column " + + s"was empty.") + assert(transposed.length == 1, s"Column-major representation of $colSize-element column " + + s"should be a single array but was ${transposed.length} arrays.") + assert(transposed(0).length == colSize, + s"Column-major representation of $colSize-element column contained " + + s"${transposed(0).length} elements") + } + + // Transform a 2x3 (non-square) dataset + { + val data = Array(Array(0, 1, 2), Array(3, 4, 5)) + val expected = Array(Array(0, 3), Array(1, 4), Array(2, 5)) + val transposed = LocalDecisionTreeUtils.rowToColumnStoreDense(data) + transposed.zip(expected).foreach { case (resultCol, expectedCol) => + assert(resultCol.sameElements(expectedCol), s"Result column" + + s"${resultCol.mkString(", ")} differed from expected col ${expectedCol.mkString(", ")}") + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala new file mode 100644 index 0000000000000..71967736e3007 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** Suite exercising helper methods for making split decisions during decision tree training. */ +class TreeSplitUtilsSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + /** + * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated + * with the data from the specified training points. + */ + private def getAggregator( + metadata: DecisionTreeMetadata, + col: FeatureColumn, + from: Int, + to: Int, + labels: Array[Double], + featureSplits: Array[Split]): DTStatsAggregator = { + + val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None) + val instanceWeights = Array.fill[Double](col.values.length)(1.0) + val indices = col.values.indices.toArray + AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels) + LocalDecisionTree.updateAggregator(statsAggregator, col, indices, instanceWeights, labels, + from, to, col.featureIndex, featureSplits) + statsAggregator + } + + /** Check that left/right impurities match what we'd expect for a split. */ + private def validateImpurityStats( + impurity: Impurity, + labels: Array[Double], + stats: ImpurityStats, + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + // Verify that impurity stats were computed correctly for split + val numClasses = (labels.max + 1).toInt + val fullImpurityStatsArray + = Array.tabulate[Double](numClasses)((label: Int) => labels.count(_ == label).toDouble) + val fullImpurity = Entropy.calculate(fullImpurityStatsArray, labels.length) + assert(stats.impurityCalculator.stats === fullImpurityStatsArray) + assert(stats.impurity === fullImpurity) + assert(stats.leftImpurityCalculator.stats === expectedLeftStats) + assert(stats.rightImpurityCalculator.stats === expectedRightStats) + assert(stats.valid) + } + + /* * * * * * * * * * * Choosing Splits * * * * * * * * * * */ + + test("chooseSplit: choose correct type of split (continuous split)") { + // Construct (binned) continuous data + val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex = 0, values = Array(8, 1, 1, 2, 3, 5, 6)) + // Get an array of continuous splits corresponding to values in our binned data + val splits = LocalTreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = 7, + numFeatures = 1, numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits) + // Choose split, check that it's a valid ContinuousSplit + val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, col.featureIndex, + col.featureIndex, splits) + assert(stats1.valid && split1.isInstanceOf[ContinuousSplit]) + } + + test("chooseSplit: choose correct type of split (categorical split)") { + // Construct categorical data + val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) + val featureIndex = 0 + val featureArity = 3 + val values = Array(0, 0, 1, 1, 1, 2, 2) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, 
compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = 7, + numFeatures = 1, numClasses = 2, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits) + // Choose split, check that it's a valid categorical split + val (split2, stats2) = SplitUtils.chooseSplit(statsAggregator = statsAggregator, + featureIndex = col.featureIndex, featureIndexIdx = col.featureIndex, + featureSplits = splits) + assert(stats2.valid && split2.isInstanceOf[CategoricalSplit]) + } + + test("chooseOrderedCategoricalSplit: basic case") { + // Helper method for testing ordered categorical split + def testHelper( + values: Array[Int], + labels: Array[Double], + expectedLeftCategories: Array[Double], + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + val featureIndex = 0 + // Construct FeatureVector to store categorical data + val featureArity = values.max + 1 + val arityMap = Map[Int, Int](featureIndex -> featureArity) + val col = FeatureColumn(featureIndex = 0, values = values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty)) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + labels, featureSplits = Array.empty) + // Choose split + val (split, stats) = + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex, + col.featureIndex) + // Verify that split has the expected left-side/right-side categories + val expectedRightCategories = Range(0, featureArity) + .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray + split match { + case s: CategoricalSplit => + assert(s.featureIndex === featureIndex) + assert(s.leftCategories === expectedLeftCategories) + assert(s.rightCategories === expectedRightCategories) + case _ => + throw new AssertionError( + s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") + } + validateImpurityStats(Entropy, labels, stats, expectedLeftStats, expectedRightStats) + } + + val values = Array(0, 0, 1, 2, 2, 2, 2) + val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0)) + + val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0)) + } + + test("chooseOrderedCategoricalSplit: return bad stats if we should not split") { + // Construct categorical data + val featureIndex = 0 + val values = Array(0, 0, 1, 2, 2, 2, 2) + val featureArity = values.max + 1 + val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity), unorderedFeatures = Some(Set.empty)) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + labels, featureSplits = Array.empty) + // Choose split, verify that it's invalid + val (_, stats) = SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex, + col.featureIndex) + assert(!stats.valid) + } + + test("chooseUnorderedCategoricalSplit: basic case") { + val featureIndex = 0 + // Construct data for unordered categorical 
feature + // label: 0 --> values: 1 + // label: 1 --> values: 0, 2 + // label: 2 --> values: 2 + val values = Array(1, 1, 0, 2, 2) + val featureArity = values.max + 1 + val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 3, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + labels, splits) + // Choose split + val (split, stats) = + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, col.featureIndex, + col.featureIndex, splits) + // Verify that split has the expected left-side/right-side categories + split match { + case s: CategoricalSplit => + assert(s.featureIndex === featureIndex) + assert(s.leftCategories.toSet === Set(1.0)) + assert(s.rightCategories.toSet === Set(0.0, 2.0)) + case _ => + throw new AssertionError( + s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") + } + validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0, 0.0), + expectedRightStats = Array(0.0, 2.0, 1.0)) + } + + test("chooseUnorderedCategoricalSplit: return bad stats if we should not split") { + // Construct data for unordered categorical feature + val featureIndex = 0 + val featureArity = 4 + val values = Array(3, 1, 0, 2, 2) + val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex, values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map(featureIndex -> featureArity)) + val splits = RandomForest.findUnorderedSplits(metadata, featureIndex) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + // Choose split, verify that it's invalid + val (_, stats) = SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, + featureIndex, splits) + assert(!stats.valid) + } + + test("chooseContinuousSplit: basic case") { + // Construct data for continuous feature + val featureIndex = 0 + val thresholds = Array(0, 1, 2, 3) + val values = thresholds.indices.toArray + val labels = Array(0.0, 0.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex = featureIndex, values = values) + + // Construct DTStatsAggregator, compute sufficient stats + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + + // Choose split, verify that it has expected threshold + val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, + featureIndex, splits) + split match { + case s: ContinuousSplit => + assert(s.featureIndex === featureIndex) + assert(s.threshold === 1) + case _ => + throw new AssertionError( + s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}") + } + // Verify impurity stats of split + validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0), + expectedRightStats = Array(0.0, 2.0)) + } + + test("chooseContinuousSplit: return bad stats if we should not split") { + // Construct data for continuous feature + val featureIndex = 0 + val thresholds = 
Array(0, 1, 2, 3) + val values = thresholds.indices.toArray + val labels = Array(0.0, 0.0, 0.0, 0.0, 0.0) + val col = FeatureColumn(featureIndex = featureIndex, values = values) + // Construct DTStatsAggregator, compute sufficient stats + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val metadata = TreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty[Int, Int]) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + // Choose split, verify that it's invalid + val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, + featureIndex, splits) + assert(!stats.valid) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index b6894b30b0c2b..ad83f4a8c7e6d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericA import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} @@ -101,6 +102,35 @@ private[ml] object TreeTests extends SparkFunSuite { data.select(data(featuresColName), data(labelColName).as(labelColName, labelMetadata)) } + /** Returns a DecisionTreeMetadata instance with hard-coded values for use in tests */ + def getMetadata( + numExamples: Int, + numFeatures: Int, + numClasses: Int, + featureArity: Map[Int, Int], + impurity: Impurity = Entropy, + unorderedFeatures: Option[Set[Int]] = None): DecisionTreeMetadata = { + // By default, assume all categorical features within tests + // have small enough arity to be treated as unordered + val unordered = unorderedFeatures.getOrElse(featureArity.keys.toSet) + + // Set numBins appropriately for categorical features + val maxBins = 4 + val numBins: Array[Int] = 0.until(numFeatures).toArray.map { featureIndex => + if (featureArity.contains(featureIndex) && featureArity(featureIndex) > 0) { + featureArity(featureIndex) + } else { + maxBins + } + } + + new DecisionTreeMetadata(numFeatures = numFeatures, numExamples = numExamples, + numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0, featureArity = featureArity, + unorderedFeatures = unordered, numBins = numBins, impurity = impurity, + quantileStrategy = null, maxDepth = 5, minInstancesPerNode = 1, numTrees = 1, + numFeaturesPerNode = 2) + } + /** * Check if the two trees are exactly the same. * Note: I hesitate to override Node.equals since it could cause problems if users @@ -194,6 +224,26 @@ private[ml] object TreeTests extends SparkFunSuite { new LabeledPoint(14.0, Vectors.dense(Array(5.0))) )) + /** + * Create toy data that can be used for testing deep tree training; the generated data requires + * [[depth]] splits to split fully. Thus a tree fit on the generated data should have a depth of + * [[depth]] (unless splitting halts early due to other constraints e.g. max depth or min + * info gain). + */ + def deepTreeData(sc: SparkContext, depth: Int): RDD[LabeledPoint] = { + // Create a dataset with [[depth]] binary features; a training point has a label of 1 + // iff all features have a value of 1. 
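+    // For example, deepTreeData(sc, depth = 2) produces the labeled points
+    // (0.0, [0, 1]), (0.0, [1, 0]), and (1.0, [1, 1]).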
+ sc.parallelize(Range(0, depth + 1).map { idx => + val features = Array.fill[Double](depth)(1) + if (idx == depth) { + LabeledPoint(1.0, Vectors.dense(features)) + } else { + features(idx) = 0.0 + LabeledPoint(0.0, Vectors.dense(features)) + } + }) + } + /** * Mapping from all Params to valid settings which differ from the defaults. * This is useful for tests which need to exercise all Params, such as save/load.