Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: minmingzhu <minming.zhu@intel.com>
  • Loading branch information
minmingzhu committed Apr 24, 2023
1 parent 9182e16 commit d8135f6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
20 changes: 10 additions & 10 deletions mllib-dal/src/main/native/DecisionForestRegressorOneAPIImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,9 @@ LearningNode convertsplitToLearningNode(
splitNode.impurity = info.get_impurity();
splitNode.sampleCount = info.get_sample_count();
std::unique_ptr<double[]> arr(new double[classCount]);
for (std::int64_t index_class = 0; index_class < classCount;
++index_class) {
arr[index_class] = 0.0;
}
arr[0] = info.get_sample_count();
arr[1] = 0.0;
arr[2] = 0.0;
splitNode.probability = std::move(arr);
return splitNode;
}
Expand All @@ -78,10 +77,9 @@ convertleafToLearningNode(const df::leaf_node_info<df::task::regression> &info,
leafNode.impurity = info.get_impurity();
leafNode.sampleCount = info.get_sample_count();
std::unique_ptr<double[]> arr(new double[classCount]);
for (std::int64_t index_class = 0; index_class < classCount;
++index_class) {
arr[index_class] = 0.0;
}
arr[0] = info.get_sample_count();
arr[1] = info.get_sample_count() * info.get_label();
arr[2] = info.get_sample_count() * info.get_label() * info.get_label();
leafNode.probability = std::move(arr);
return leafNode;
}
Expand Down Expand Up @@ -294,8 +292,10 @@ static jobject doRFRegressorOneAPICompute(
<< result_infer.get_responses() << std::endl;

// convert c++ map to java hashmap
collect_model(env, result_train.get_model(), 0, treeForest);
trees = convertRFRJavaMap(env, treeForest, 0);
jint statsSize = 3; // spark create VarianceCalculator needs array of
// sufficient statistics
collect_model(env, result_train.get_model(), statsSize, treeForest);
trees = convertRFRJavaMap(env, treeForest, statsSize);

// Get the class of the input object
jclass clazz = env->GetObjectClass(resultObj);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.spark.ml.tree
import com.intel.oap.mllib.classification.{LearningNode => LearningNodeDAL}

import org.apache.spark.ml.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{GiniCalculator, VarianceCalculator}
import org.apache.spark.mllib.tree.impurity.{Gini, GiniCalculator, ImpurityCalculator, Variance, VarianceCalculator}
import org.apache.spark.mllib.tree.model.ImpurityStats

object TreeUtils {
Expand Down Expand Up @@ -58,10 +58,11 @@ object TreeUtils {
val ln: LearningNodeDAL = nodes.get(i)
i += 1

val impurityCalculator = if (metadata.impurity == "gini") {
new GiniCalculator(ln.probability, ln.sampleCount)
} else {
new VarianceCalculator(ln.probability, ln.sampleCount)
val impurityCalculator: ImpurityCalculator = metadata.impurity match {
case Gini => new GiniCalculator(ln.probability, ln.sampleCount)
case Variance => new VarianceCalculator(ln.probability, ln.sampleCount)
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: " +
s"${metadata.impurity}")
}

val impurityStats = new ImpurityStats(0, ln.impurity, impurityCalculator, null, null)
Expand Down

0 comments on commit d8135f6

Please sign in to comment.