diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 19df8f7edd43..53a606d009f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -201,11 +201,18 @@ abstract class ProbabilisticClassificationModel[ probability.argmax } else { val thresholds: Array[Double] = getThresholds - val scaledProbability: Array[Double] = - probability.toArray.zip(thresholds).map { case (p, t) => - if (t == 0.0) Double.PositiveInfinity else p / t - } - Vectors.dense(scaledProbability).argmax + + if (thresholds.contains(0.0)) { + val indices = thresholds.zipWithIndex.filter(_._1 == 0.0).map(_._2) + val values = indices.map(probability.apply) + Vectors.sparse(numClasses, indices, values).argmax + } else { + val scaledProbability: Array[Double] = + probability.toArray.zip(thresholds).map { case (p, t) => + if (t == 0.0) Double.PositiveInfinity else p / t + } + Vectors.dense(scaledProbability).argmax + } } } }