Skip to content

Commit 4ec606d

Browse files
committed
deal with zero thresholds
1 parent 2a3d286 commit 4ec606d

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,18 @@ abstract class ProbabilisticClassificationModel[
201201
probability.argmax
202202
} else {
203203
val thresholds: Array[Double] = getThresholds
204-
val scaledProbability: Array[Double] =
205-
probability.toArray.zip(thresholds).map { case (p, t) =>
206-
if (t == 0.0) Double.PositiveInfinity else p / t
207-
}
208-
Vectors.dense(scaledProbability).argmax
204+
205+
if (thresholds.contains(0.0)) {
206+
val indices = thresholds.zipWithIndex.filter(_._1 == 0.0).map(_._2)
207+
val values = indices.map(probability.apply)
208+
Vectors.sparse(numClasses, indices, values).argmax
209+
} else {
210+
val scaledProbability: Array[Double] =
211+
probability.toArray.zip(thresholds).map { case (p, t) =>
212+
if (t == 0.0) Double.PositiveInfinity else p / t
213+
}
214+
Vectors.dense(scaledProbability).argmax
215+
}
209216
}
210217
}
211218
}

0 commit comments

Comments
 (0)