Skip to content

Commit

Permalink
update
Browse files: browse the repository at this point in the history
Signed-off-by: minmingzhu <minming.zhu@intel.com>
  • Loading branch information
minmingzhu committed Apr 21, 2023
1 parent bb7ec7e commit c965934
Showing 1 changed file with 3 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

package org.apache.spark.ml.classification

import org.apache.spark.{SparkConf, SparkFunSuite}
import com.intel.oneapi.dal.table.Common
import org.apache.spark.{SparkConf, SparkFunSuite, TestCommon}
import org.apache.spark.ml.classification.LinearSVCSuite.generateSVMInput
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
Expand All @@ -42,7 +43,7 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
/**
 * Spark configuration used by this suite: the parent suite's configuration
 * with the "spark.oap.mllib.device" key set to the GPU compute device.
 */
override def sparkConf: SparkConf = {
  val conf = super.sparkConf
  // Set the device key exactly once, from the enum constant. A preceding
  // hard-coded "GPU" assignment to the same key would be overwritten by this
  // call anyway, so it is not repeated here.
  conf.set("spark.oap.mllib.device", Common.ComputeDevice.GPU.toString)
}

private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
Expand All @@ -65,78 +66,13 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
// Tests calling train()
/////////////////////////////////////////////////////////////////////////////

// def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier): Unit = {
// val categoricalFeatures = Map.empty[Int, Int]
// val numClasses = 2
// val newRF = rf
// .setImpurity("Gini")
// .setMaxDepth(2)
// .setNumTrees(1)
// .setFeatureSubsetStrategy("auto")
// .setSeed(123)
// compareAPIs(orderedLabeledPoints50_1000, newRF, categoricalFeatures, numClasses)
// }

// Runs the generic ParamsSuite checks first against a fresh estimator and
// then against a minimal, hand-assembled single-tree model.
test("params") {
  val estimator = new RandomForestClassifier
  ParamsSuite.checkParams(estimator)

  // Smallest possible forest: one tree whose root is a single leaf node.
  val rootLeaf = new LeafNode(0.0, 0.0, null)
  val singleTree = new DecisionTreeClassificationModel("dtc", rootLeaf, 1, 2)
  val forest = new RandomForestClassificationModel("rfc", Array(singleTree), 2, 2)
  ParamsSuite.checkParams(forest)
}

// test("Binary classification with continuous features:" +
// " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
// val rf = new RandomForestClassifier()
// .setBootstrap(false)
// binaryClassificationTestWithContinuousFeatures(rf)
// }

// test("Binary classification with continuous features and node Id cache:" +
// " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
// val rf = new RandomForestClassifier()
// .setBootstrap(false)
// .setCacheNodeIds(true)
// binaryClassificationTestWithContinuousFeatures(rf)
// }

// test("alternating categorical and continuous features with multiclass labels to test indexing") {
// val arr = Array(
// LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)),
// LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)),
// LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)),
// LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0))
// )
// val rdd = sc.parallelize(arr)
// val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4)
// val numClasses = 3
//
// val rf = new RandomForestClassifier()
// .setImpurity("Gini")
// .setMaxDepth(5)
// .setNumTrees(2)
// .setFeatureSubsetStrategy("sqrt")
// .setSeed(12345)
// compareAPIs(rdd, rf, categoricalFeatures, numClasses)
// }

// test("subsampling rate in RandomForest") {
// val rdd = orderedLabeledPoints5_20
// val categoricalFeatures = Map.empty[Int, Int]
// val numClasses = 2
//
// val rf1 = new RandomForestClassifier()
// .setImpurity("Gini")
// .setMaxDepth(2)
// .setCacheNodeIds(true)
// .setNumTrees(3)
// .setFeatureSubsetStrategy("auto")
// .setSeed(123)
// compareAPIs(rdd, rf1, categoricalFeatures, numClasses)
//
// val rf2 = rf1.setSubsamplingRate(0.5)
// compareAPIs(rdd, rf2, categoricalFeatures, numClasses)
// }

test("predictRaw and predictProbability") {
val rdd = orderedLabeledPoints5_20
val rf = new RandomForestClassifier()
Expand Down

0 comments on commit c965934

Please sign in to comment.