diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala index f4e2040569f4..3b6bb2ad1357 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala @@ -22,6 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.evaluation.binary._ import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.storage.StorageLevel /** * Evaluator for binary classification. @@ -165,13 +166,17 @@ class BinaryClassificationMetrics @Since("3.0.0") ( confusions: RDD[(Double, BinaryConfusionMatrix)]) = { // Create a bin for each distinct score value, count weighted positives and // negatives within each bin, and then sort by score values in descending order. - val counts = scoreLabelsWeight.combineByKey( + val binnedWeights = scoreLabelsWeight.combineByKey( createCombiner = (labelAndWeight: (Double, Double)) => new BinaryLabelCounter(0.0, 0.0) += (labelAndWeight._1, labelAndWeight._2), mergeValue = (c: BinaryLabelCounter, labelAndWeight: (Double, Double)) => c += (labelAndWeight._1, labelAndWeight._2), mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2 - ).sortByKey(ascending = false) + ) + if (scoreLabelsWeight.getStorageLevel != StorageLevel.NONE) { + binnedWeights.persist() + } + val counts = binnedWeights.sortByKey(ascending = false) val binnedCounts = // Only down-sample if bins is > 0 @@ -215,6 +220,7 @@ class BinaryClassificationMetrics @Since("3.0.0") ( val partitionwiseCumulativeCounts = agg.scanLeft(new BinaryLabelCounter())((agg, c) => agg.clone() += c) val totalCount = partitionwiseCumulativeCounts.last + binnedWeights.unpersist() logInfo(s"Total counts: $totalCount") val cumulativeCounts = binnedCounts.mapPartitionsWithIndex( (index: Int, iter: Iterator[(Double, BinaryLabelCounter)]) => {