From 9182e161b74ee6e14980a10df080f639b30c7782 Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Sun, 23 Apr 2023 15:33:20 +0000 Subject: [PATCH] update Signed-off-by: minmingzhu --- examples/random-forest/regression/pom.xml | 2 +- .../ml/regression/spark321/RandomForestRegressor.scala | 4 ++-- .../main/scala/org/apache/spark/ml/tree/TreeUtils.scala | 9 +++++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/random-forest/regression/pom.xml b/examples/random-forest/regression/pom.xml index 403e842fe..30e2a4566 100644 --- a/examples/random-forest/regression/pom.xml +++ b/examples/random-forest/regression/pom.xml @@ -7,7 +7,7 @@ 1.3.1 jar - KMeansExample + RandomForestRegressorExample https://github.com/oap-project/oap-mllib.git diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark321/RandomForestRegressor.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark321/RandomForestRegressor.scala index 756d9e3fe..d8ad470ec 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark321/RandomForestRegressor.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/regression/spark321/RandomForestRegressor.scala @@ -209,8 +209,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S => { val rootNode = TreeUtils.buildTreeDFS(nodelist, metadata) new DecisionTreeRegressionModel(uid, - rootNode.toNode(), - numFeatures) + rootNode.toNode(), + numFeatures) } }.toArray } diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/tree/TreeUtils.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/tree/TreeUtils.scala index f96b8ae67..4e94556cf 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/tree/TreeUtils.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/tree/TreeUtils.scala @@ -21,7 +21,7 @@ package org.apache.spark.ml.tree import com.intel.oap.mllib.classification.{LearningNode => LearningNodeDAL} import org.apache.spark.ml.tree.impl.DecisionTreeMetadata -import org.apache.spark.mllib.tree.impurity.GiniCalculator +import org.apache.spark.mllib.tree.impurity.{GiniCalculator, VarianceCalculator} import org.apache.spark.mllib.tree.model.ImpurityStats object TreeUtils { @@ -58,7 +58,12 @@ object TreeUtils { val ln: LearningNodeDAL = nodes.get(i) i += 1 - val impurityCalculator = new GiniCalculator(ln.probability, ln.sampleCount) + val impurityCalculator = if (metadata.impurity == "gini") { + new GiniCalculator(ln.probability, ln.sampleCount) + } else { + new VarianceCalculator(ln.probability, ln.sampleCount) + } + val impurityStats = new ImpurityStats(0, ln.impurity, impurityCalculator, null, null) val node = LearningNode.apply(0, ln.isLeaf, impurityStats) node.split = if (!ln.isLeaf) {