diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 8c8e4a161aa5b..76d2be3c632a3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import scala.collection.mutable
 
 /**
  * Model for Naive Bayes Classifiers.
@@ -65,6 +66,42 @@ class NaiveBayesModel private[mllib] (
   override def predict(testData: Vector): Double = {
     labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
   }
+
+  /**
+   * Compute the posterior class probabilities for every vector in an RDD.
+   *
+   * @param testData RDD of feature vectors to score
+   * @return RDD with, for each input vector, a map from class label to posterior probability
+   */
+  def classProbabilities(testData: RDD[Vector]): RDD[scala.collection.Map[Double, Double]] = {
+    // Broadcast the model so each executor fetches it once instead of once per task.
+    val bcModel = testData.context.broadcast(this)
+    testData.mapPartitions { iter =>
+      val model = bcModel.value
+      iter.map(v => model.classProbabilities(v))
+    }
+  }
+
+  /**
+   * Compute the posterior class probabilities for a single vector.
+   *
+   * The model stores log-priors and log-conditionals, so `brzPi + brzTheta * x` yields
+   * unnormalized log-posteriors; they are exponentiated and normalized to sum to one.
+   *
+   * @param testData feature vector to score
+   * @return map from class label to posterior probability
+   */
+  def classProbabilities(testData: Vector): scala.collection.Map[Double, Double] = {
+    val logPosteriors = brzPi + brzTheta * testData.toBreeze
+    // Shift by the largest log-posterior before exponentiating (log-sum-exp trick) so that
+    // exp() does not underflow to zero for every class.
+    val maxLog = logPosteriors(brzArgmax(logPosteriors))
+    val probs = mutable.Map.empty[Double, Double]
+    logPosteriors.foreachPair((k, v) => probs += (labels(k) -> math.exp(v - maxLog)))
+    val total = probs.values.sum
+    probs.transform((_, p) => p / total)
+  }
+
 }
 
 /**