diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 19df8f7edd43..a83d98e246fe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors, VectorUDT} +import org.apache.spark.ml.linalg.{DenseVector, Vector, VectorUDT} import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -193,19 +193,24 @@ abstract class ProbabilisticClassificationModel[ /** * Given a vector of class conditional probabilities, select the predicted label. - * This supports thresholds which favor particular labels. - * @return predicted label + * This returns the class, if any, whose probability is equal to or greater than its + * threshold (if specified), and whose probability is highest. If several classes meet + * their thresholds and are equally probable, the one with lower threshold is selected. + * If several have equal thresholds, the one with lower class index is selected. + * + * @return predicted label, or NaN if no label is predicted */ protected def probability2prediction(probability: Vector): Double = { - if (!isDefined(thresholds)) { - probability.argmax + val prob = probability.toArray + if (isDefined(thresholds)) { + val candidates = prob.zip(getThresholds).zipWithIndex.filter { case ((p, t), _) => p >= t } + if (candidates.isEmpty) { + Double.NaN + } else { + candidates.maxBy { case ((p, t), i) => (p, -t, -i) }._2 + } } else { - val thresholds: Array[Double] = getThresholds - val scaledProbability: Array[Double] = - probability.toArray.zip(thresholds).map { case (p, t) => - if (t == 0.0) Double.PositiveInfinity else p / t - } - Vectors.dense(scaledProbability).argmax + prob.zipWithIndex.maxBy { case (p, i) => (p, -i) }._2 } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala index 0913fe559c56..9952578d1abd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultinomialLogisticRegressionSuite.scala @@ -988,22 +988,22 @@ class MultinomialLogisticRegressionSuite val basePredictions = model.transform(dataset).select("prediction").collect() // should predict all zeros - model.setThresholds(Array(1, 1000, 1000)) + model.setThresholds(Array(0, 1, 1)) val zeroPredictions = model.transform(dataset).select("prediction").collect() assert(zeroPredictions.forall(_.getDouble(0) === 0.0)) // should predict all ones - model.setThresholds(Array(1000, 1, 1000)) + model.setThresholds(Array(1, 0, 1)) val onePredictions = model.transform(dataset).select("prediction").collect() assert(onePredictions.forall(_.getDouble(0) === 1.0)) // should predict all twos - model.setThresholds(Array(1000, 1000, 1)) + model.setThresholds(Array(1, 1, 0)) val twoPredictions = model.transform(dataset).select("prediction").collect() assert(twoPredictions.forall(_.getDouble(0) === 2.0)) // constant threshold scaling is the same as no thresholds - model.setThresholds(Array(1000, 1000, 1000)) + model.setThresholds(Array(0.1, 0.1, 0.1)) val scaledPredictions = model.transform(dataset).select("prediction").collect() assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) => scaled.getDouble(0) === base.getDouble(0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index b3bd2b3e57b3..7f92e449a0b5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -36,8 +36,8 @@ final class TestProbabilisticClassificationModel( rawPrediction } - def friendlyPredict(input: Vector): Double = { - predict(input) + def friendlyPredict(input: Double*): Double = { + predict(Vectors.dense(input.toArray)) } } @@ -45,17 +45,44 @@ final class TestProbabilisticClassificationModel( class ProbabilisticClassifierSuite extends SparkFunSuite { test("test thresholding") { - val thresholds = Array(0.5, 0.2) val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - .setThresholds(thresholds) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0) + .setThresholds(Array(0.5, 0.2)) + // Both exceed threshold; pick more probable one + assert(testModel.friendlyPredict(0.8, 0.9) === 1.0) + assert(testModel.friendlyPredict(1.0, 0.2) === 0.0) + // Tie; take one with lower threshold + assert(testModel.friendlyPredict(0.8, 0.8) === 1.0) + // Tie at 1 + assert(testModel.friendlyPredict(1.0, 1.0) === 1.0) + // Class 0 more probable but doesn't meet threshold + assert(testModel.friendlyPredict(0.4, 0.3) === 1.0) + // Neither meets threshold + assert(testModel.friendlyPredict(0.4, 0.1).isNaN) + assert(testModel.friendlyPredict(0.0, 0.0).isNaN) } - test("test thresholding not required") { + test("test equals thresholds") { val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) - assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0) + .setThresholds(Array(0.5, 0.5)) + // Both exceed threshold; pick more probable one + assert(testModel.friendlyPredict(0.8, 0.9) === 1.0) + // Tie; take one with lower class + assert(testModel.friendlyPredict(0.8, 0.8) === 0.0) + assert(testModel.friendlyPredict(0.5, 0.5) === 0.0) + // Neither meets threshold + assert(testModel.friendlyPredict(0.4, 0.1).isNaN) } + + test("test no thresholding") { + val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2) + // Pick more probable class + assert(testModel.friendlyPredict(1.0, 2.0) === 1.0) + // Tie, pick first class + assert(testModel.friendlyPredict(1.0, 1.0) === 0.0) + assert(testModel.friendlyPredict(0.5, 0.5) === 0.0) + assert(testModel.friendlyPredict(0.0, 0.0) === 0.0) + } + } object ProbabilisticClassifierSuite {