34 changes: 34 additions & 0 deletions docs/ml-classification-regression.md
@@ -302,6 +302,40 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe
</div>
</div>

## Naive Bayes

[Naive Bayes classifiers](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple
probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence
assumptions between the features. The `spark.ml` implementation currently supports both [multinomial
naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html)
and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib).
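
As a minimal sketch (not part of this patch), the variant is selected on the `ml.NaiveBayes` estimator through its `modelType` parameter, with additive (Laplace) smoothing controlled by `smoothing`:

```scala
import org.apache.spark.ml.classification.NaiveBayes

// Multinomial naive Bayes (the default) models term counts,
// e.g. word frequencies in document classification.
val multinomialNB = new NaiveBayes()
  .setModelType("multinomial")
  .setSmoothing(1.0)  // additive (Laplace) smoothing; 1.0 is the default

// Bernoulli naive Bayes models binary (0/1) feature presence.
val bernoulliNB = new NaiveBayes()
  .setModelType("bernoulli")
```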

**Contributor** commented:

I think it's better to clarify that `ml.NaiveBayes` supports Multinomial NB and Bernoulli NB. Meanwhile, we should provide links to the corresponding documents. You can refer to the NaiveBayes API doc.

**Contributor Author** replied:

Thanks for taking a look. The wiki link already provides a good overall introduction to Naive Bayes; I'll add some clarification here. The MLlib documentation also clarifies that naive Bayes supports both Multinomial and Bernoulli.

**Example**

<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes) for more details.

{% include_example scala/org/apache/spark/examples/ml/NaiveBayesExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/NaiveBayes.html) for more details.

{% include_example java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes) for more details.

{% include_example python/ml/naive_bayes_example.py %}
</div>
</div>
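
Beyond the bundled examples above, a fitted `NaiveBayesModel` exposes the learned log class priors (`pi`) and log class-conditional probabilities (`theta`), and `transform` appends `rawPrediction`, `probability`, and `prediction` columns. A minimal sketch, assuming `model` and `test` as in the examples:

```scala
// Assumes `model` is a fitted org.apache.spark.ml.classification.NaiveBayesModel
// and `test` is a DataFrame with a "features" column, as in the examples above.
println(model.pi)     // log of the class prior probabilities
println(model.theta)  // log of the class-conditional feature probabilities

model.transform(test)
  .select("features", "probability", "prediction")
  .show(5)
```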


# Regression

64 changes: 64 additions & 0 deletions examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;


import org.apache.spark.SparkConf;
**Member** commented on this line:

I would not include SparkConf or JavaSparkContext.

import org.apache.spark.api.java.JavaSparkContext;
// $example on$
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.classification.NaiveBayesModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
// $example off$

/**
* An example for Naive Bayes Classification.
*/
public class JavaNaiveBayesExample {

public static void main(String[] args) {
SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext jsql = new SQLContext(jsc);

// $example on$
// Load training data
Dataset<Row> dataFrame = jsql.read().format("libsvm")
  .load("data/mllib/sample_libsvm_data.txt");
// Split the data into train and test
Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];

// create the trainer and set its parameters
NaiveBayes nb = new NaiveBayes();
// train the model
NaiveBayesModel model = nb.fit(train);
// compute precision on the test set
Dataset<Row> result = model.transform(test);
Dataset<Row> predictionAndLabels = result.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setMetricName("precision");
System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels));
// $example off$

jsc.stop();
}
}
53 changes: 53 additions & 0 deletions examples/src/main/python/ml/naive_bayes_example.py
@@ -0,0 +1,53 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

from pyspark import SparkContext
from pyspark.sql import SQLContext
# $example on$
from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# $example off$

if __name__ == "__main__":

sc = SparkContext(appName="naive_bayes_example")
sqlContext = SQLContext(sc)

# $example on$
# Load training data
data = sqlContext.read.format("libsvm") \
.load("data/mllib/sample_libsvm_data.txt")
# Split the data into train and test
splits = data.randomSplit([0.6, 0.4], 1234)
train = splits[0]
test = splits[1]

# create the trainer and set its parameters
nb = NaiveBayes(smoothing=1.0, modelType="multinomial")

# train the model
model = nb.fit(train)
# compute precision on the test set
result = model.transform(test)
predictionAndLabels = result.select("prediction", "label")
evaluator = MulticlassClassificationEvaluator(metricName="precision")
print("Precision:" + str(evaluator.evaluate(predictionAndLabels)))
# $example off$

sc.stop()
58 changes: 58 additions & 0 deletions examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// scalastyle:off println
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
// $example on$
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// $example off$
import org.apache.spark.sql.SQLContext

object NaiveBayesExample {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("NaiveBayesExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
// $example on$
// Load the data stored in LIBSVM format as a DataFrame.
val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Split the data into training and test sets (30% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a NaiveBayes model.
val model = new NaiveBayes()
.fit(trainingData)

// Select example rows to display.
val predictions = model.transform(testData)
predictions.show()

// Select (prediction, true label) and compute test error
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("precision")
val precision = evaluator.evaluate(predictions)
println("Precision:" + precision)
// $example off$
}
}
// scalastyle:on println