Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: minmingzhu <minming.zhu@intel.com>
  • Loading branch information
minmingzhu committed Apr 23, 2023
1 parent e74ca63 commit 9182e16
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/random-forest/regression/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<version>1.3.1</version>
<packaging>jar</packaging>

<name>KMeansExample</name>
<name>RandomForestRegressorExample</name>
<url>https://github.com/oap-project/oap-mllib.git</url>

<properties>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
=> {
val rootNode = TreeUtils.buildTreeDFS(nodelist, metadata)
new DecisionTreeRegressionModel(uid,
rootNode.toNode(),
numFeatures)
rootNode.toNode(),
numFeatures)
}
}.toArray
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.spark.ml.tree
import com.intel.oap.mllib.classification.{LearningNode => LearningNodeDAL}

import org.apache.spark.ml.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.GiniCalculator
import org.apache.spark.mllib.tree.impurity.{GiniCalculator, VarianceCalculator}
import org.apache.spark.mllib.tree.model.ImpurityStats

object TreeUtils {
Expand Down Expand Up @@ -58,7 +58,12 @@ object TreeUtils {
val ln: LearningNodeDAL = nodes.get(i)
i += 1

val impurityCalculator = new GiniCalculator(ln.probability, ln.sampleCount)
val impurityCalculator = if (metadata.impurity == "gini") {
new GiniCalculator(ln.probability, ln.sampleCount)
} else {
new VarianceCalculator(ln.probability, ln.sampleCount)
}

val impurityStats = new ImpurityStats(0, ln.impurity, impurityCalculator, null, null)
val node = LearningNode.apply(0, ln.isLeaf, impurityStats)
node.split = if (!ln.isLeaf) {
Expand Down

0 comments on commit 9182e16

Please sign in to comment.