Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging {
}
case _ => featureSubsetStrategy
}

val isIntRegex = "^([1-9]\\d*)$".r
val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
val numFeaturesPerNode: Int = _featureSubsetStrategy match {
case "all" => numFeatures
case "sqrt" => math.sqrt(numFeatures).ceil.toInt
case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
case "onethird" => (numFeatures / 3.0).ceil.toInt
case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
}

new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
* - "onethird": use 1/3 of the features
* - "sqrt": use sqrt(number of features)
* - "log2": use log2(number of features)
* - "n": when n is a real number in the range (0, 1.0], use n * number of features.
* When n is an integer in the range [1, number of features], use n features.
* (default = "auto")
*
* These various settings are based on the following references:
Expand All @@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
"The number of features to consider for splits at each tree node." +
s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
(value: String) =>
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
|| value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))

setDefault(featureSubsetStrategy -> "auto")

Expand Down Expand Up @@ -393,6 +396,9 @@ private[spark] object RandomForestParams {
// These options should be lowercase.
final val supportedFeatureSubsetStrategies: Array[String] =
Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)

// The regex to capture a fraction in "(0.0-1.0]", or a positive integer "n".
// Note that the regex itself does not bound n by the number of features.
final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
}

private[ml] trait RandomForestClassifierParams
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ import org.apache.spark.util.Utils
* @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
* @param featureSubsetStrategy Number of features to consider for splits at each node.
* Supported values: "auto", "all", "sqrt", "log2", "onethird".
* Supported numerical values: "(0.0-1.0]", "[1-n]".
* If "auto" is set, this parameter is set based on numTrees:
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* If a real value "n" in the range (0, 1.0] is set,
* use n * number of features.
* If an integer value "n" in the range [1, num features] is set,
* use n features.
* @param seed Random seed for bootstrapping and choosing feature subsets.
*/
private class RandomForest (
Expand All @@ -70,9 +75,11 @@ private class RandomForest (

strategy.assertValid()
require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),
require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
|| featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.")
s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
s" (0.0-1.0], [1-n].")

/**
* Method to train a decision tree model over an RDD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

Expand Down Expand Up @@ -80,6 +81,24 @@ public void runDT() {
for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
for (String strategy: realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
for (String strategy: integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
for (String strategy: invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");
} catch (Exception e) {
Assert.assertTrue(e instanceof IllegalArgumentException);
}
}

RandomForestClassificationModel model = rf.fit(dataFrame);

model.transform(dataFrame);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

Expand Down Expand Up @@ -80,6 +81,24 @@ public void runDT() {
for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) {
rf.setFeatureSubsetStrategy(featureSubsetStrategy);
}
String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"};
for (String strategy: realStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String integerStrategies[] = {"1", "10", "100", "1000", "10000"};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing 0 should round up to 1, yes? We should test this edge case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what happens with negative values? Those should not be allowed - just want to confirm the regex excludes that (we should add some test cases)

for (String strategy: integerStrategies) {
rf.setFeatureSubsetStrategy(strategy);
}
String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
for (String strategy: invalidStrategies) {
try {
rf.setFeatureSubsetStrategy(strategy);
Assert.fail("Expected exception to be thrown for invalid strategies");
} catch (Exception e) {
Assert.assertTrue(e instanceof IllegalArgumentException);
}
}

RandomForestRegressionModel model = rf.fit(dataFrame);

model.transform(dataFrame);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt)

val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0")
for (strategy <- realStrategies) {
val expected = (strategy.toDouble * numFeatures).ceil.toInt
checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
}

val integerStrategies = Array("1", "10", "100", "1000", "10000")
for (strategy <- integerStrategies) {
val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
checkFeatureSubsetStrategy(numTrees = 1, strategy, expected)
}

val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
for (invalidStrategy <- invalidStrategies) {
intercept[MatchError]{
val metadata =
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
}
}

checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures)
checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "log2",
(math.log(numFeatures) / math.log(2)).ceil.toInt)
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)

for (strategy <- realStrategies) {
val expected = (strategy.toDouble * numFeatures).ceil.toInt
checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
}

for (strategy <- integerStrategies) {
val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures
checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
}
for (invalidStrategy <- invalidStrategies) {
intercept[MatchError]{
val metadata =
DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
}
}
}

test("Binary classification with continuous features: subsampling features") {
Expand Down