mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala

@@ -18,8 +18,10 @@
 package org.apache.spark.ml.tree.impl
 
 import scala.collection.mutable
+import scala.util.Try
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.RandomForestParams
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -184,15 +186,22 @@ private[spark] object DecisionTreeMetadata extends Logging {
       case _ => featureSubsetStrategy
     }
 
-    val isIntRegex = "^([1-9]\\d*)$".r
-    val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
     val numFeaturesPerNode: Int = _featureSubsetStrategy match {
       case "all" => numFeatures
       case "sqrt" => math.sqrt(numFeatures).ceil.toInt
       case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
       case "onethird" => (numFeatures / 3.0).ceil.toInt
-      case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
-      case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
+      case _ =>
+        Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match {
+          case Some(value) => math.min(value, numFeatures)
+          case None =>
+            Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
+              case Some(value) => math.ceil(value * numFeatures).toInt
+              case _ => throw new IllegalArgumentException(s"Supported values:" +
+                s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
+                s" (0.0-1.0], [1-n].")
+            }
+        }
     }
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
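The net effect of this hunk is easier to read outside the diff. Below is a standalone sketch (not the Spark source) of the rule the new catch-all implements: a positive integer string selects that many features, capped at numFeatures; a fraction in (0.0, 1.0] selects a ceiling share of the features; anything else raises IllegalArgumentException. Because parsing now delegates to toInt/toDouble, any string the JDK parsers accept within those bounds is valid, and bad input fails with a descriptive error rather than the MatchError the old non-exhaustive match produced. The name parseStrategy and the sample values are illustrative only.

```scala
import scala.util.Try

// Illustrative sketch of the new featureSubsetStrategy parsing rules.
def parseStrategy(strategy: String, numFeatures: Int): Int = strategy match {
  case "all" => numFeatures
  case "sqrt" => math.sqrt(numFeatures).ceil.toInt
  case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
  case "onethird" => (numFeatures / 3.0).ceil.toInt
  case _ =>
    // First try a positive integer count, capped at the feature count ...
    Try(strategy.toInt).filter(_ > 0).toOption match {
      case Some(k) => math.min(k, numFeatures)
      case None =>
        // ... then a fraction in (0.0, 1.0], rounded up.
        Try(strategy.toDouble).filter(f => f > 0 && f <= 1.0).toOption match {
          case Some(f) => math.ceil(f * numFeatures).toInt
          case None => throw new IllegalArgumentException(
            s"Unsupported featureSubsetStrategy: $strategy")
        }
    }
}

// With numFeatures = 10:
//   parseStrategy("3", 10)   == 3
//   parseStrategy("100", 10) == 10  (capped at numFeatures)
//   parseStrategy("0.5", 10) == 5   (ceil(0.5 * 10))
//   parseStrategy("1.1", 10) throws IllegalArgumentException
```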
mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala (6 additions & 5 deletions)

@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.tree
 
+import scala.util.Try
+
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -346,10 +348,12 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
    */
   final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
     "The number of features to consider for splits at each tree node." +
-      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
+      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
+      s", (0.0-1.0], [1-n].",
     (value: String) =>
       RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
-      || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
+      || Try(value.toInt).filter(_ > 0).isSuccess
+      || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
 
   setDefault(featureSubsetStrategy -> "auto")
 
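Pulled out of the Param constructor, the updated validator reduces to a predicate like the one below; a minimal sketch, where isValidStrategy is an illustrative name rather than anything in Spark:

```scala
import scala.util.Try

// Minimal sketch of the Param validator's acceptance rule.
def isValidStrategy(value: String): Boolean = {
  val named = Set("auto", "all", "onethird", "sqrt", "log2")
  named.contains(value.toLowerCase) ||                            // named heuristics
    Try(value.toInt).filter(_ > 0).isSuccess ||                   // positive integer
    Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess  // fraction in (0.0, 1.0]
}

// isValidStrategy("log2") == true
// isValidStrategy("7")    == true
// isValidStrategy("0.5")  == true
// isValidStrategy("0")    == false  (zero features)
// isValidStrategy("1.1")  == false  (fraction above 1.0)
```

Note that Try(...).filter(p).isSuccess collapses both a parse failure and a failed bounds check into false, which is why the next hunk can delete the regex constant outright.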
@@ -396,9 +400,6 @@ private[spark] object RandomForestParams {
   // These options should be lowercase.
   final val supportedFeatureSubsetStrategies: Array[String] =
     Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
-
-  // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features)
-  final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
 }
 
 private[ml] trait RandomForestClassifierParams
mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala

@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
+import scala.util.Try
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
@@ -76,9 +77,10 @@ private class RandomForest (
   strategy.assertValid()
   require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
   require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
-    || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
+    || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess
+    || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess,
     s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
     s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
     s" (0.0-1.0], [1-n].")
 
   /**
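With the relaxed require, the old RDD-based entry points accept numeric strings as well. A hedged usage sketch under that assumption; input, strategy, and trainWithFraction are placeholders, not part of this patch:

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.rdd.RDD

// `input` and `strategy` stand in for the caller's data and configuration.
def trainWithFraction(input: RDD[LabeledPoint], strategy: Strategy) = {
  // "0.3" now means: consider ceil(0.3 * numFeatures) features at each node.
  RandomForest.trainClassifier(
    input, strategy, numTrees = 10, featureSubsetStrategy = "0.3", seed = 42)
}
```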
mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala

@@ -440,7 +440,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
     for (invalidStrategy <- invalidStrategies) {
-      intercept[MatchError]{
+      intercept[IllegalArgumentException]{
         val metadata =
           DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
       }
@@ -463,7 +463,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
     }
     for (invalidStrategy <- invalidStrategies) {
-      intercept[MatchError]{
+      intercept[IllegalArgumentException]{
         val metadata =
           DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
       }