diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index c7cde1563fc79..5f7c40f6071f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -18,8 +18,10 @@ package org.apache.spark.ml.tree.impl import scala.collection.mutable +import scala.util.Try import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.RandomForestParams import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -184,15 +186,22 @@ private[spark] object DecisionTreeMetadata extends Logging { case _ => featureSubsetStrategy } - val isIntRegex = "^([1-9]\\d*)$".r - val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r val numFeaturesPerNode: Int = _featureSubsetStrategy match { case "all" => numFeatures case "sqrt" => math.sqrt(numFeatures).ceil.toInt case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) case "onethird" => (numFeatures / 3.0).ceil.toInt - case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt - case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt + case _ => + Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match { + case Some(value) => math.min(value, numFeatures) + case None => + Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match { + case Some(value) => math.ceil(value * numFeatures).toInt + case _ => throw new IllegalArgumentException(s"Supported values:" + + s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") + } + } } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 0767dc17e5562..55e12acc1c51e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import scala.util.Try + import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -346,10 +348,12 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { */ final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", "The number of features to consider for splits at each tree node." + - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" + + s", (0.0-1.0], [1-n].", (value: String) => RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) - || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex)) + || Try(value.toInt).filter(_ > 0).isSuccess + || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess) setDefault(featureSubsetStrategy -> "auto") @@ -396,9 +400,6 @@ private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) - - // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features) - final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$" } private[ml] trait RandomForestClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 26755849ad1a2..ca7fb7f51c3fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.util.Try import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD @@ -76,9 +77,10 @@ private class RandomForest ( strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) - || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex), + || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess + || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess, s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + - s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," + s" (0.0-1.0], [1-n].") /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 6db9ce150d930..1719f9fab5345 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -440,7 +440,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") for (invalidStrategy <- invalidStrategies) { - intercept[MatchError]{ + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) } @@ -463,7 +463,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) } for (invalidStrategy <- invalidStrategies) { - intercept[MatchError]{ + intercept[IllegalArgumentException]{ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) }