From 6f90db19e64a7a8a29ddf024afed8bf0a51ab7d2 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 20 May 2015 14:38:22 -0700 Subject: [PATCH 1/9] [SPARK-7574][ml][doc] User guide for OneVsRest --- .../sample_multiclass_classification_data.txt | 150 ++++++++++++++++++ docs/ml-ensembles.md | 73 +++++++++ docs/ml-guide.md | 3 +- 3 files changed, 225 insertions(+), 1 deletion(-) create mode 100644 data/mllib/sample_multiclass_classification_data.txt create mode 100644 docs/ml-ensembles.md diff --git a/data/mllib/sample_multiclass_classification_data.txt b/data/mllib/sample_multiclass_classification_data.txt new file mode 100644 index 000000000000..a0d7f9011391 --- /dev/null +++ b/data/mllib/sample_multiclass_classification_data.txt @@ -0,0 +1,150 @@ +1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333 +1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667 +1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333 +1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667 +0 1:0.166667 2:-0.416667 3:0.457627 4:0.5 +1 1:-0.833333 3:-0.864407 4:-0.916667 +2 1:-1.32455e-07 2:-0.166667 3:0.220339 4:0.0833333 +2 1:-1.32455e-07 2:-0.333333 3:0.0169491 4:-4.03573e-08 +1 1:-0.5 2:0.75 3:-0.830508 4:-1 +0 1:0.611111 3:0.694915 4:0.416667 +0 1:0.222222 2:-0.166667 3:0.423729 4:0.583333 +1 1:-0.722222 2:-0.166667 3:-0.864407 4:-1 +1 1:-0.5 2:0.166667 3:-0.864407 4:-0.916667 +2 1:-0.222222 2:-0.333333 3:0.0508474 4:-4.03573e-08 +2 1:-0.0555556 2:-0.833333 3:0.0169491 4:-0.25 +2 1:-0.166667 2:-0.416667 3:-0.0169491 4:-0.0833333 +1 1:-0.944444 3:-0.898305 4:-0.916667 +2 1:-0.277778 2:-0.583333 3:-0.0169491 4:-0.166667 +0 1:0.111111 2:-0.333333 3:0.38983 4:0.166667 +2 1:-0.222222 2:-0.166667 3:0.0847457 4:-0.0833333 +0 1:0.166667 2:-0.333333 3:0.559322 4:0.666667 +1 1:-0.611111 2:0.0833333 3:-0.864407 4:-0.916667 +2 1:-0.333333 2:-0.583333 3:0.0169491 4:-4.03573e-08 +0 1:0.555555 2:-0.166667 3:0.661017 4:0.666667 +2 1:0.166667 3:0.186441 4:0.166667 +2 1:0.111111 2:-0.75 3:0.152542 4:-4.03573e-08 +2 1:0.166667 2:-0.25 3:0.118644 4:-4.03573e-08 +0 1:-0.0555556 2:-0.833333 3:0.355932 4:0.166667 +0 1:-0.277778 2:-0.333333 3:0.322034 4:0.583333 +2 1:-0.222222 2:-0.5 3:-0.152542 4:-0.25 +2 1:-0.111111 3:0.288136 4:0.416667 +2 1:-0.0555556 2:-0.25 3:0.186441 4:0.166667 +2 1:0.333333 2:-0.166667 3:0.355932 4:0.333333 +1 1:-0.611111 2:0.25 3:-0.898305 4:-0.833333 +0 1:0.166667 2:-0.333333 3:0.559322 4:0.75 +0 1:0.111111 2:-0.25 3:0.559322 4:0.416667 +0 1:0.833333 2:-0.166667 3:0.898305 4:0.666667 +2 1:-0.277778 2:-0.166667 3:0.186441 4:0.166667 +0 1:-0.666667 2:-0.583333 3:0.186441 4:0.333333 +1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 +1 1:-0.166667 2:0.666667 3:-0.932203 4:-0.916667 +0 1:0.0555554 2:-0.333333 3:0.288136 4:0.416667 +1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 +1 1:-0.833333 2:0.166667 3:-0.864407 4:-0.833333 +0 1:0.0555554 2:0.166667 3:0.491525 4:0.833333 +0 1:0.722222 2:-0.333333 3:0.728813 4:0.5 +2 1:-0.166667 2:-0.416667 3:0.0508474 4:-0.25 +2 1:0.5 3:0.254237 4:0.0833333 +0 1:0.111111 2:-0.583333 3:0.355932 4:0.5 +1 1:-0.944444 2:-0.166667 3:-0.898305 4:-0.916667 +2 1:0.277778 2:-0.25 3:0.220339 4:-4.03573e-08 +0 1:0.666667 2:-0.25 3:0.79661 4:0.416667 +0 1:0.111111 2:0.0833333 3:0.694915 4:1 +0 1:0.444444 3:0.59322 4:0.833333 +2 1:-0.0555556 2:0.166667 3:0.186441 4:0.25 +1 1:-0.833333 2:0.333333 3:-1 4:-0.916667 +1 1:-0.555556 2:0.416667 3:-0.830508 4:-0.75 +2 1:-0.333333 2:-0.5 3:0.152542 4:-0.0833333 +1 1:-1 2:-0.166667 3:-0.966102 4:-1 +1 1:-0.333333 2:0.25 3:-0.898305 4:-0.916667 +2 1:0.388889 2:-0.333333 3:0.288136 4:0.0833333 +2 1:0.277778 2:-0.166667 3:0.152542 4:0.0833333 +0 1:0.333333 2:0.0833333 3:0.59322 4:0.666667 +1 1:-0.777778 3:-0.79661 4:-0.916667 +1 1:-0.444444 2:0.416667 3:-0.830508 4:-0.916667 +0 1:0.222222 2:-0.166667 3:0.627119 4:0.75 +1 1:-0.555556 2:0.5 3:-0.79661 4:-0.916667 +1 1:-0.555556 2:0.5 3:-0.694915 4:-0.75 +2 1:-1.32455e-07 2:-0.25 3:0.254237 4:0.0833333 +1 1:-0.5 2:0.25 3:-0.830508 4:-0.916667 +0 1:0.166667 3:0.457627 4:0.833333 +2 1:0.444444 2:-0.0833334 3:0.322034 4:0.166667 +0 1:0.111111 2:0.166667 3:0.559322 4:0.916667 +1 1:-0.611111 2:0.25 3:-0.79661 4:-0.583333 +0 1:0.388889 3:0.661017 4:0.833333 +1 1:-0.722222 2:0.166667 3:-0.79661 4:-0.916667 +1 1:-0.722222 2:-0.0833334 3:-0.79661 4:-0.916667 +1 1:-0.555556 2:0.166667 3:-0.830508 4:-0.916667 +2 1:-0.666667 2:-0.666667 3:-0.220339 4:-0.25 +2 1:-0.611111 2:-0.75 3:-0.220339 4:-0.25 +2 1:0.0555554 2:-0.833333 3:0.186441 4:0.166667 +0 1:-0.166667 2:-0.416667 3:0.38983 4:0.5 +0 1:0.611111 2:0.333333 3:0.728813 4:1 +2 1:0.0555554 2:-0.25 3:0.118644 4:-4.03573e-08 +1 1:-0.666667 2:-0.166667 3:-0.864407 4:-0.916667 +1 1:-0.833333 2:-0.0833334 3:-0.830508 4:-0.916667 +0 1:0.611111 2:-0.166667 3:0.627119 4:0.25 +0 1:0.888889 2:0.5 3:0.932203 4:0.75 +2 1:0.222222 2:-0.333333 3:0.220339 4:0.166667 +1 1:-0.555556 2:0.25 3:-0.864407 4:-0.833333 +0 1:-1.32455e-07 2:-0.166667 3:0.322034 4:0.416667 +0 1:-1.32455e-07 2:-0.5 3:0.559322 4:0.0833333 +1 1:-0.611111 3:-0.932203 4:-0.916667 +1 1:-0.333333 2:0.833333 3:-0.864407 4:-0.916667 +0 1:-0.166667 2:-0.333333 3:0.38983 4:0.916667 +2 1:-0.333333 2:-0.666667 3:-0.0847458 4:-0.25 +2 1:-0.0555556 2:-0.416667 3:0.38983 4:0.25 +1 1:-0.388889 2:0.416667 3:-0.830508 4:-0.916667 +0 1:0.444444 2:-0.0833334 3:0.38983 4:0.833333 +1 1:-0.611111 2:0.333333 3:-0.864407 4:-0.916667 +0 1:0.111111 2:-0.416667 3:0.322034 4:0.416667 +0 1:0.166667 2:-0.0833334 3:0.525424 4:0.416667 +2 1:0.333333 2:-0.0833334 3:0.152542 4:0.0833333 +0 1:-0.0555556 2:-0.166667 3:0.288136 4:0.416667 +0 1:-0.166667 2:-0.416667 3:0.38983 4:0.5 +1 1:-0.611111 2:0.166667 3:-0.830508 4:-0.916667 +0 1:0.888889 2:-0.166667 3:0.728813 4:0.833333 +2 1:-0.277778 2:-0.25 3:-0.118644 4:-4.03573e-08 +2 1:-0.222222 2:-0.333333 3:0.186441 4:-4.03573e-08 +0 1:0.333333 2:-0.583333 3:0.627119 4:0.416667 +0 1:0.444444 2:-0.0833334 3:0.491525 4:0.666667 +2 1:-0.222222 2:-0.25 3:0.0847457 4:-4.03573e-08 +1 1:-0.611111 2:0.166667 3:-0.79661 4:-0.75 +2 1:-0.277778 2:-0.166667 3:0.0508474 4:-4.03573e-08 +0 1:1 2:0.5 3:0.830508 4:0.583333 +2 1:-0.333333 2:-0.666667 3:-0.0508475 4:-0.166667 +2 1:-0.277778 2:-0.416667 3:0.0847457 4:-4.03573e-08 +0 1:0.888889 2:-0.333333 3:0.932203 4:0.583333 +2 1:-0.111111 2:-0.166667 3:0.0847457 4:0.166667 +2 1:0.111111 2:-0.583333 3:0.322034 4:0.166667 +0 1:0.333333 2:0.0833333 3:0.59322 4:1 +0 1:0.222222 2:-0.166667 3:0.525424 4:0.416667 +1 1:-0.555556 2:0.5 3:-0.830508 4:-0.833333 +0 1:-0.111111 2:-0.166667 3:0.38983 4:0.416667 +0 1:0.888889 2:-0.5 3:1 4:0.833333 +1 1:-0.388889 2:0.583333 3:-0.898305 4:-0.75 +2 1:0.111111 2:0.0833333 3:0.254237 4:0.25 +0 1:0.333333 2:-0.166667 3:0.423729 4:0.833333 +1 1:-0.388889 2:0.166667 3:-0.762712 4:-0.916667 +0 1:0.333333 2:-0.0833334 3:0.559322 4:0.916667 +2 1:-0.333333 2:-0.75 3:0.0169491 4:-4.03573e-08 +1 1:-0.222222 2:1 3:-0.830508 4:-0.75 +1 1:-0.388889 2:0.583333 3:-0.762712 4:-0.75 +2 1:-0.611111 2:-1 3:-0.152542 4:-0.25 +2 1:-1.32455e-07 2:-0.333333 3:0.254237 4:-0.0833333 +2 1:-0.5 2:-0.416667 3:-0.0169491 4:0.0833333 +1 1:-0.888889 2:-0.75 3:-0.898305 4:-0.833333 +1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 +2 1:-0.555556 2:-0.583333 3:-0.322034 4:-0.166667 +2 1:-0.166667 2:-0.5 3:0.0169491 4:-0.0833333 +1 1:-0.555556 2:0.0833333 3:-0.762712 4:-0.666667 +1 1:-0.777778 3:-0.898305 4:-0.916667 +0 1:0.388889 2:-0.166667 3:0.525424 4:0.666667 +0 1:0.222222 3:0.38983 4:0.583333 +2 1:0.333333 2:-0.0833334 3:0.254237 4:0.166667 +2 1:-0.388889 2:-0.166667 3:0.186441 4:0.166667 +0 1:-0.222222 2:-0.583333 3:0.355932 4:0.583333 +1 1:-0.611111 2:-0.166667 3:-0.79661 4:-0.916667 +1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667 +1 1:-0.388889 2:0.166667 3:-0.830508 4:-0.75 diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md new file mode 100644 index 000000000000..f32d1147a200 --- /dev/null +++ b/docs/ml-ensembles.md @@ -0,0 +1,73 @@ +--- +layout: global +title: Ensembles - MLlib +displayTitle: MLlib - Ensembles +--- + +* Table of contents +{:toc} + +An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) +is a learning algorithm which creates a model composed of a set of other base models. +ML supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) + +## OneVsRest + +[`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. + +`OneVsRest` is an `Estimator` takes as base classifier instances of [`Classifier`](api/scala/index.html#org.apache.spark.ml.classification.Classifier) and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. + +Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. + +### Example + +The example below demonstrates how to load a +[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), parse it as an RDD of `LabeledPoint` and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. + +
+
+{% highlight scala %} +import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} +import org.apache.spark.ml.util.MetadataUtils +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.functions._ + +val sqlContext = new SQLContext(sc) +import sqlContext.implicits._ + +// parse data into dataframe +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") +.toDF() +.withColumn("rnd", rand(0)) + +val train = data.filter($"rnd" < 0.8) +val test = data.filter($"rnd" >= 0.8) + +// instantiate multiclass learner and train +val ovr = new OneVsRest().setClassifier(new LogisticRegression) + +val ovrModel = ovr.fit(train) + +// score model on test data +val predictions = ovrModel.transform(test).select("prediction", "label") + +val predictionsRDD = predictions.map {case Row(p: Double, l: Double) => (p, l)} + +// compute confusion matrix +val metrics = new MulticlassMetrics(predictionsRDD) + +println(metrics.confusionMatrix) + +// compute the false positive rate per label +val predictionColSchema = predictions.schema("prediction") +val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get +val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) + +println("label\tfpr") + +println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) + +{% endhighlight %} +
diff --git a/docs/ml-guide.md b/docs/ml-guide.md index b7b6376e061f..0e69e4f8f8a8 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -150,11 +150,12 @@ This is useful if there are two algorithms with the `maxIter` parameter in a `Pi # Algorithm Guides -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines. +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. **Pipelines API Algorithm Guides** * [Feature Extraction, Transformation, and Selection](ml-features.html) +* [Ensembles](ml-ensembles.html) # Code Examples From bb9dbfab77fd34a9dce213ab94672a9616895702 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 20 May 2015 14:51:35 -0700 Subject: [PATCH 2/9] Clean up naming --- docs/ml-ensembles.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index f32d1147a200..175e8958088f 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -1,7 +1,7 @@ --- layout: global -title: Ensembles - MLlib -displayTitle: MLlib - Ensembles +title: Ensembles +displayTitle: ML - Ensembles --- * Table of contents From 13bed9c046a2f6faa20550d57e1c9ca0ec39068e Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 20 May 2015 14:59:19 -0700 Subject: [PATCH 3/9] add wikipedia link --- docs/ml-ensembles.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 175e8958088f..42a42b67491f 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -13,9 +13,9 @@ ML supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.htm ## OneVsRest -[`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. +[OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. -`OneVsRest` is an `Estimator` takes as base classifier instances of [`Classifier`](api/scala/index.html#org.apache.spark.ml.classification.Classifier) and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +[`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) is implemented as an `Estimator` takes as base classifier instances of [`Classifier`](api/scala/index.html#org.apache.spark.ml.classification.Classifier) and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. From 4b7d1a6f49c937a0ab476fd735bfd63ed78e8e6d Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 20 May 2015 15:04:03 -0700 Subject: [PATCH 4/9] minor cleanup --- docs/ml-ensembles.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 42a42b67491f..e6caa6582887 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -15,7 +15,7 @@ ML supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.htm [OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. -[`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classification.OneVsRest) is implemented as an `Estimator` takes as base classifier instances of [`Classifier`](api/scala/index.html#org.apache.spark.ml.classification.Classifier) and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +`OneVsRest` is implemented as an `Estimator` takes as base classifier instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. From c0266138ba4e426b3e0b989537639af13e3ff3ed Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Wed, 20 May 2015 22:32:24 -0700 Subject: [PATCH 5/9] Code Review fixes --- docs/ml-ensembles.md | 37 ++++++++++++++----------------------- docs/ml-guide.md | 2 +- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index e6caa6582887..6133e244ffe8 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -4,46 +4,42 @@ title: Ensembles displayTitle: ML - Ensembles --- -* Table of contents +**Table of Contents** + +* This will become a table of contents (this text will be scraped). {:toc} An [ensemble method](http://en.wikipedia.org/wiki/Ensemble_learning) is a learning algorithm which creates a model composed of a set of other base models. -ML supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) +The Pipelines API supports the following ensemble algorithms: [`OneVsRest`](api/scala/index.html#org.apache.spark.ml.classifier.OneVsRest) ## OneVsRest [OneVsRest](http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest) is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. -`OneVsRest` is implemented as an `Estimator` takes as base classifier instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. +`OneVsRest` is implemented as an `Estimator`. For the base classifier it takes instances of `Classifier` and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes. Predictions are done by evaluating each binary classifier and the index of the most confident classifier is output as label. ### Example The example below demonstrates how to load a -[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), parse it as an RDD of `LabeledPoint` and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. +[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy.
{% highlight scala %} import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} -import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.functions._ val sqlContext = new SQLContext(sc) import sqlContext.implicits._ // parse data into dataframe val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") -.toDF() -.withColumn("rnd", rand(0)) - -val train = data.filter($"rnd" < 0.8) -val test = data.filter($"rnd" >= 0.8) +val Array(train, test) = data.toDF().randomSplit(Array(0.7, 0.3)) // instantiate multiclass learner and train val ovr = new OneVsRest().setClassifier(new LogisticRegression) @@ -52,22 +48,17 @@ val ovrModel = ovr.fit(train) // score model on test data val predictions = ovrModel.transform(test).select("prediction", "label") - -val predictionsRDD = predictions.map {case Row(p: Double, l: Double) => (p, l)} +val predictionsAndLabels = predictions.map {case Row(p: Double, l: Double) => (p, l)} // compute confusion matrix -val metrics = new MulticlassMetrics(predictionsRDD) - +val metrics = new MulticlassMetrics(predictionsAndLabels) println(metrics.confusionMatrix) -// compute the false positive rate per label -val predictionColSchema = predictions.schema("prediction") -val numClasses = MetadataUtils.getNumClasses(predictionColSchema).get -val fprs = Range(0, numClasses).map(p => (p, metrics.falsePositiveRate(p.toDouble))) - -println("label\tfpr") - -println(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n")) +// the Iris DataSet has three classes +val numClasses = 3 +val fprs = (0 until numClasses).map(label => (label, metrics.falsePositiveRate(label.toDouble))) +println("label\tfpr\n%s".format(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n"))) {% endhighlight %}
+
diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 0e69e4f8f8a8..ec3ae888027a 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -150,7 +150,7 @@ This is useful if there are two algorithms with the `maxIter` parameter in a `Pi # Algorithm Guides -There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. +There are now several algorithms in the Pipelines API which are not in the lower-level MLlib API, so we link to documentation for them here. These algorithms are mostly feature transformers, which fit naturally into the `Transformer` abstraction in Pipelines, and ensembles, which fit naturally into the `Estimator` abstraction in the Pipelines. **Pipelines API Algorithm Guides** From ebdf1035a64f0167dfe9a7a57b3bae1324114072 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Fri, 22 May 2015 10:46:27 -0700 Subject: [PATCH 6/9] Java Example --- docs/ml-ensembles.md | 64 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 6133e244ffe8..86754fa8534b 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -61,4 +61,68 @@ val fprs = (0 until numClasses).map(label => (label, metrics.falsePositiveRate(l println("label\tfpr\n%s".format(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n"))) {% endhighlight %}
+
+{% highlight java %} + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.classification.OneVsRest; +import org.apache.spark.mllib.evaluation.MulticlassMetrics; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; + +SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); +JavaSparkContext jsc = new JavaSparkContext(conf); +SQLContext jsql = new SQLContext(jsc); + +RDD data = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_multiclass_classification_data.txt"); + +RDD[] split = data.randomSplit(new double[]{0.7, 0.3}, 12345); +RDD train = split[0]; +RDD test = split[1]; + +// instantiate the One Vs Rest Classifier +OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression()); + +// train the multiclass model +DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); +OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); + +// score the model on test data +DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); +DataFrame predictions = ovrModel + .transform(testDataFrame.cache()) + .select("prediction", "label"); + +// obtain metrics +MulticlassMetrics metrics = new MulticlassMetrics(predictions); +Matrix confusionMatrix = metrics.confusionMatrix(); + +// output the Confusion Matrix +System.out.println("Confusion Matrix"); +System.out.println(confusionMatrix); + +// compute the false positive rate per label +StringBuilder results = new StringBuilder(); +results.append("label\tfpr\n"); + +// the Iris DataSet has three classes +int numClasses = 3; + +for (int label = 0; label < numClasses; label++) { + results.append(label); + results.append("\t"); + results.append(metrics.falsePositiveRate((double) label)); + results.append("\n"); +} +System.out.println(); +System.out.println(results); +{% endhighlight %} +
From 2f762959808abb39eb2bb0597b79ea6e7b6885af Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Fri, 22 May 2015 11:00:56 -0700 Subject: [PATCH 7/9] Code Review Fixes --- docs/ml-ensembles.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 86754fa8534b..37f976290fa9 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -23,8 +23,8 @@ Predictions are done by evaluating each binary classifier and the index of the m ### Example -The example below demonstrates how to load a -[LIBSVM data file](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy. +The example below demonstrates how to load the +[Iris dataset](http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale), parse it as a DataFrame and perform multiclass classification using `OneVsRest`. The test error is calculated to measure the algorithm accuracy.
@@ -35,7 +35,6 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{Row, SQLContext} val sqlContext = new SQLContext(sc) -import sqlContext.implicits._ // parse data into dataframe val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") @@ -57,8 +56,8 @@ println(metrics.confusionMatrix) // the Iris DataSet has three classes val numClasses = 3 -val fprs = (0 until numClasses).map(label => (label, metrics.falsePositiveRate(label.toDouble))) -println("label\tfpr\n%s".format(fprs.map {case (label, fpr) => label + "\t" + fpr}.mkString("\n"))) +val fprs = (0 until numClasses).map(label => label + "\t" + metrics.falsePositiveRate(label.toDouble)).mkString("\n") +println("label\tfpr\n" + fprs) {% endhighlight %}
@@ -81,7 +80,8 @@ SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); -RDD data = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_multiclass_classification_data.txt"); +RDD data = MLUtils.loadLibSVMFile(jsc.sc(), + "data/mllib/sample_multiclass_classification_data.txt"); RDD[] split = data.randomSplit(new double[]{0.7, 0.3}, 12345); RDD train = split[0]; From 46c41b17c10a3efbb1ada91a8d1af72f629d2d13 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Fri, 22 May 2015 11:42:05 -0700 Subject: [PATCH 8/9] cleanup --- docs/ml-ensembles.md | 47 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 37f976290fa9..0cb49681b36e 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -56,8 +56,11 @@ println(metrics.confusionMatrix) // the Iris DataSet has three classes val numClasses = 3 -val fprs = (0 until numClasses).map(label => label + "\t" + metrics.falsePositiveRate(label.toDouble)).mkString("\n") -println("label\tfpr\n" + fprs) +println("label\tfpr\n") +(0 until numClasses).foreach { index => + val label = index.toDouble + println(label + "\t" + metrics.falsePositiveRate(label)) +} {% endhighlight %}
@@ -67,38 +70,37 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.OneVsRest; +import org.apache.spark.ml.classification.OneVsRestModel; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.util.MLUtils; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); - + RDD data = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_multiclass_classification_data.txt"); -RDD[] split = data.randomSplit(new double[]{0.7, 0.3}, 12345); -RDD train = split[0]; -RDD test = split[1]; +DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class); +DataFrame[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3}, 12345); +DataFrame train = splits[0]; +DataFrame test = splits[1]; // instantiate the One Vs Rest Classifier OneVsRest ovr = new OneVsRest().setClassifier(new LogisticRegression()); - + // train the multiclass model -DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); -OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); +OneVsRestModel ovrModel = ovr.fit(train.cache()); // score the model on test data -DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); DataFrame predictions = ovrModel - .transform(testDataFrame.cache()) - .select("prediction", "label"); + .transform(test) + .select("prediction", "label"); // obtain metrics MulticlassMetrics metrics = new MulticlassMetrics(predictions); @@ -109,20 +111,19 @@ System.out.println("Confusion Matrix"); System.out.println(confusionMatrix); // compute the false positive rate per label -StringBuilder results = new StringBuilder(); -results.append("label\tfpr\n"); +System.out.println(); +System.out.println("label\tfpr\n"); // the Iris DataSet has three classes int numClasses = 3; - -for (int label = 0; label < numClasses; label++) { - results.append(label); - results.append("\t"); - results.append(metrics.falsePositiveRate((double) label)); - results.append("\n"); +for (int index = 0; index < numClasses; index++) { + double label = (double) index; + System.out.print(label); + System.out.print("\t"); + System.out.print(metrics.falsePositiveRate(label)); + System.out.println(); } -System.out.println(); -System.out.println(results); + {% endhighlight %}
From 645427cb1d6617c2acb73914f55cf20e276cae51 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Fri, 22 May 2015 12:03:59 -0700 Subject: [PATCH 9/9] cleanup --- docs/ml-ensembles.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ml-ensembles.md b/docs/ml-ensembles.md index 0cb49681b36e..9ff50e95fc47 100644 --- a/docs/ml-ensembles.md +++ b/docs/ml-ensembles.md @@ -37,7 +37,8 @@ import org.apache.spark.sql.{Row, SQLContext} val sqlContext = new SQLContext(sc) // parse data into dataframe -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") +val data = MLUtils.loadLibSVMFile(sc, + "data/mllib/sample_multiclass_classification_data.txt") val Array(train, test) = data.toDF().randomSplit(Array(0.7, 0.3)) // instantiate multiclass learner and train @@ -123,7 +124,6 @@ for (int index = 0; index < numClasses; index++) { System.out.print(metrics.falsePositiveRate(label)); System.out.println(); } - {% endhighlight %}