File tree Expand file tree Collapse file tree 1 file changed +12
-5
lines changed
mllib/src/main/scala/org/apache/spark/ml/classification Expand file tree Collapse file tree 1 file changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -201,11 +201,18 @@ abstract class ProbabilisticClassificationModel[
201201 probability.argmax
202202 } else {
203203 val thresholds : Array [Double ] = getThresholds
204- val scaledProbability : Array [Double ] =
205- probability.toArray.zip(thresholds).map { case (p, t) =>
206- if (t == 0.0 ) Double .PositiveInfinity else p / t
207- }
208- Vectors .dense(scaledProbability).argmax
204+
205+ if (thresholds.contains(0.0 )) {
206+ val indices = thresholds.zipWithIndex.filter(_._1 == 0.0 ).map(_._2)
207+ val values = indices.map(probability.apply)
208+ Vectors .sparse(numClasses, indices, values).argmax
209+ } else {
210+ val scaledProbability : Array [Double ] =
211+ probability.toArray.zip(thresholds).map { case (p, t) =>
212+ if (t == 0.0 ) Double .PositiveInfinity else p / t
213+ }
214+ Vectors .dense(scaledProbability).argmax
215+ }
209216 }
210217 }
211218}
You can’t perform that action at this time.
0 commit comments