Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,10 @@ object DecisionTreeRunner {
// Create training, test sets.
val splits = if (params.testInput != "") {
// Load testInput.
val numFeatures = examples.take(1)(0).features.size
val origTestExamples = params.dataFormat match {
case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput)
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput)
case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput, numFeatures)
}
params.algo match {
case Classification => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ private[tree] class DTStatsAggregator(
numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
}

/**
* Indicator for each feature of whether that feature is an unordered feature.
* TODO: Is Array[Boolean] any faster?
*/
def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex)

/**
* Total number of elements stored in this aggregator
*/
Expand Down Expand Up @@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator(
* Pre-compute feature offset for use with [[featureUpdate]].
* For ordered features only.
*/
def getFeatureOffset(featureIndex: Int): Int = {
require(!isUnordered(featureIndex),
s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" +
s" for unordered feature $featureIndex.")
featureOffsets(featureIndex)
}
def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)

/**
* Pre-compute feature offset for use with [[featureUpdate]].
* For unordered features only.
*/
def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
require(isUnordered(featureIndex),
s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," +
s" but was called for ordered feature $featureIndex.")
val baseOffset = featureOffsets(featureIndex)
(baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl

import scala.collection.mutable

import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
Expand Down Expand Up @@ -82,7 +83,7 @@ private[tree] class DecisionTreeMetadata(

}

private[tree] object DecisionTreeMetadata {
private[tree] object DecisionTreeMetadata extends Logging {

/**
* Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
Expand All @@ -103,6 +104,10 @@ private[tree] object DecisionTreeMetadata {
}

val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
if (maxPossibleBins < strategy.maxBins) {
logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
s" (= number of training instances)")
}

// We check the number of bins here against maxPossibleBins.
// This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

package org.apache.spark.mllib.tree.model

import org.apache.spark.annotation.DeveloperApi

/**
* Predicted value for a node
* @param predict predicted value
* @param prob probability of the label (classification only)
*/
private[tree] class Predict(
@DeveloperApi
class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) ext
*/
override def toString: String = algo match {
case Classification =>
s"RandomForestModel classifier with $numTrees trees"
s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
case Regression =>
s"RandomForestModel regressor with $numTrees trees"
s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
case _ => throw new IllegalArgumentException(
s"RandomForestModel given unknown algo parameter: $algo.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,22 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
}

test("alternating categorical and continuous features with multiclass labels to test indexing") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0))
arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0))
arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4)
val input = sc.parallelize(arr)

val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
RandomForestSuite.validateClassifier(model, arr, 1.0)
}

}

object RandomForestSuite {
Expand Down