[MLLIB] SPARK-4362: Added classProbabilities method for Naive Bayes #3626
```diff
@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
+import scala.collection.mutable

 /**
  * Model for Naive Bayes Classifiers.
```
```diff
@@ -65,6 +66,25 @@ class NaiveBayesModel private[mllib] (
   override def predict(testData: Vector): Double = {
     labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))
   }

+  def classProbabilities(testData: RDD[Vector]):
+      RDD[scala.collection.Map[Double, Double]] = {
+    val bcModel = testData.context.broadcast(this)
+    testData.mapPartitions { iter =>
+      val model = bcModel.value
+      iter.map(model.classProbabilities)
+    }
+  }
+
+  def classProbabilities(testData: Vector): scala.collection.Map[Double, Double] = {
+    val posteriors = (brzPi + brzTheta * testData.toBreeze)
```
**Member comment:** These aren't quite the class probabilities, as the expression doesn't incorporate the probability of the evidence. This would work if you first normalized the probabilities to sum to 1.
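For context on this comment: MLlib's `NaiveBayesModel` stores `brzPi` and `brzTheta` in log space, so `brzPi + brzTheta * testData.toBreeze` is a vector of unnormalized *log*-posteriors. Dividing those raw log values by their sum (as the patch goes on to do) does not yield probabilities; the scores must first be exponentiated and then normalized, i.e. a softmax. A hedged sketch in Python of that normalization (the score values are illustrative, and `class_probabilities` is a hypothetical helper, not MLlib API):

```python
import math

def class_probabilities(log_posteriors):
    """Turn unnormalized log-posteriors (log prior + log likelihood)
    into class probabilities via a numerically stable softmax."""
    m = max(log_posteriors)                       # subtract max to avoid overflow
    exps = [math.exp(v - m) for v in log_posteriors]
    z = sum(exps)                                 # log-sum-exp denominator
    return [e / z for e in exps]

scores = [-3.2, -1.1, -5.0]   # hypothetical log-posteriors for 3 classes
probs = class_probabilities(scores)
```

Note that normalizing the log scores directly would even flip the ranking here (the most negative score would get the largest share), whereas the softmax preserves the argmax used by `predict`.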
```diff
+    val sum = posteriors.sum
+    val probs:mutable.Map[Double,Double] =
+      mutable.Map.empty[Double, Double]
```
**Member comment:** This can fit on one line.
```diff
+    posteriors.foreachPair((k,v) => probs += (labels(k) -> v/sum))
```
**Member comment:** Scala style: space after the comma, and spaces around the `/` operator.
```diff
+    probs
+  }
+
 }

 /**
```
**Reviewer comment:** Classes are indexed from 0, so would it work to return (for each instance) a Vector of probabilities rather than a `Map[Double, Double]`? That seems simpler (and more efficient).

Also, I have a PR [#3637] for a separate part of MLlib which adds a similar method for predicting class conditional probabilities. I'd like us to use the same name for the prediction method, but I'm open about choosing a name. I used "predictProbabilities" to (a) have it start with "predict" like other prediction methods and (b) leave open the possibility of supporting a similar method for regression algorithms (which can predict probability distributions). But I'll agree "classProbabilities" is more specific. Do you have strong preferences?
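To illustrate the reviewer's point about return types: when labels are the class indices 0..k-1, a dense array whose i-th entry is the probability of class i carries the same information as a `Map` keyed by label, without per-entry boxing and hashing. A small sketch in Python (all names and values hypothetical):

```python
labels = [0.0, 1.0, 2.0]   # MLlib labels are Doubles, indexed from 0
probs  = [0.2, 0.5, 0.3]   # normalized class probabilities

# Shape used by the patch: Map[label -> probability]
as_map = dict(zip(labels, probs))

# Shape the reviewer suggests: dense vector indexed by class
as_vec = list(probs)

# When labels are 0..k-1, the two encodings are interchangeable:
# as_map[float(i)] == as_vec[i] for every class i.
```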
**Author comment:** Sorry for the delay; I have no strong preference, but "predictProbabilities" makes sense for consistency. I can make that change along with the style fixes mentioned above.

My stats background is not super strong, and @jatinpreet seemed to imply there's a correctness issue with this PR. Can anyone comment on whether I've got the math wrong?