Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/ml-classification-regression.md
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,13 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe

{% include_example java/org/apache/spark/examples/ml/JavaOneVsRestExample.java %}
</div>

<div data-lang="python" markdown="1">

Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.OneVsRest) for more details.

{% include_example python/ml/one_vs_rest_example.py %}
</div>
</div>

## Naive Bayes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,222 +17,68 @@

package org.apache.spark.examples.ml;

import org.apache.commons.cli.*;

// $example on$
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.util.MetadataUtils;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
// $example off$
import org.apache.spark.sql.SparkSession;


/**
* An example runner for Multiclass to Binary Reduction with One Vs Rest.
* The example uses Logistic Regression as the base classifier. All parameters that
* can be specified on the base classifier can be passed in to the runner options.
* An example of Multiclass to Binary Reduction with One Vs Rest,
* using Logistic Regression as the base classifier.
* Run with
* <pre>
* bin/run-example ml.JavaOneVsRestExample [options]
* bin/run-example ml.JavaOneVsRestExample
* </pre>
*/
public class JavaOneVsRestExample {

// Holder for the settings read from the command line; populated by parse(String[])
// and consumed by main. Field initializers are the defaults used when an option is absent.
private static class Params {
// Path to the labeled input examples (the required "input" option).
String input;
// Optional path to a separate test set; when null, main splits the input instead.
String testInput = null;
// Maximum number of iterations for the Logistic Regression base classifier.
Integer maxIter = 100;
// Convergence tolerance of iterations for Logistic Regression.
double tol = 1E-6;
// Whether Logistic Regression fits an intercept term.
boolean fitIntercept = true;
// Regularization parameter; null means main does not set it on the classifier.
Double regParam = null;
// ElasticNet mixing parameter; null means main does not set it on the classifier.
Double elasticNetParam = null;
// Fraction of the input held out for testing; only used when testInput is null.
double fracTest = 0.2;
}

public static void main(String[] args) {
// parse the arguments
Params params = parse(args);
SparkSession spark = SparkSession
.builder()
.appName("JavaOneVsRestExample")
.getOrCreate();

// $example on$
// configure the base classifier
LogisticRegression classifier = new LogisticRegression()
.setMaxIter(params.maxIter)
.setTol(params.tol)
.setFitIntercept(params.fitIntercept);
// load data file.
Dataset<Row> inputData = spark.read().format("libsvm")
.load("data/mllib/sample_multiclass_classification_data.txt");

if (params.regParam != null) {
classifier.setRegParam(params.regParam);
}
if (params.elasticNetParam != null) {
classifier.setElasticNetParam(params.elasticNetParam);
}
// generate the train/test split.
Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> train = tmp[0];
Dataset<Row> test = tmp[1];

// instantiate the One Vs Rest Classifier
OneVsRest ovr = new OneVsRest().setClassifier(classifier);

String input = params.input;
Dataset<Row> inputData = spark.read().format("libsvm").load(input);
Dataset<Row> train;
Dataset<Row> test;
// configure the base classifier.
LogisticRegression classifier = new LogisticRegression()
.setMaxIter(10)
.setTol(1E-6)
.setFitIntercept(true);

// compute the train/ test split: if testInput is not provided use part of input
String testInput = params.testInput;
if (testInput != null) {
train = inputData;
// compute the number of features in the training set.
int numFeatures = inputData.first().<Vector>getAs(1).size();
test = spark.read().format("libsvm").option("numFeatures",
String.valueOf(numFeatures)).load(testInput);
} else {
double f = params.fracTest;
Dataset<Row>[] tmp = inputData.randomSplit(new double[]{1 - f, f}, 12345);
train = tmp[0];
test = tmp[1];
}
// instantiate the One Vs Rest Classifier.
OneVsRest ovr = new OneVsRest().setClassifier(classifier);

// train the multiclass model
OneVsRestModel ovrModel = ovr.fit(train.cache());
// train the multiclass model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry one last thing - here we use train.cache() but we don't do that in the other examples. Actually in general we don't seem to do that in any other examples from a quick look. So perhaps remove that and just do ovr.fit(train);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. I will fix it

OneVsRestModel ovrModel = ovr.fit(train);

// score the model on test data
Dataset<Row> predictions = ovrModel.transform(test.cache())
// score the model on test data.
Dataset<Row> predictions = ovrModel.transform(test)
.select("prediction", "label");

// obtain metrics
MulticlassMetrics metrics = new MulticlassMetrics(predictions);
StructField predictionColSchema = predictions.schema().apply("prediction");
Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();

// compute the false positive rate per label
StringBuilder results = new StringBuilder();
results.append("label\tfpr\n");
for (int label = 0; label < numClasses; label++) {
results.append(label);
results.append("\t");
results.append(metrics.falsePositiveRate((double) label));
results.append("\n");
}
// obtain evaluator.
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setMetricName("precision");

Matrix confusionMatrix = metrics.confusionMatrix();
// output the Confusion Matrix
System.out.println("Confusion Matrix");
System.out.println(confusionMatrix);
System.out.println();
System.out.println(results);
// compute the classification error on test data.
double precision = evaluator.evaluate(predictions);
System.out.println("Test Error : " + (1 - precision));
// $example off$

spark.stop();
}

/**
 * Parses the command-line arguments into a {@link Params} instance.
 *
 * Options that are not supplied keep the defaults declared in {@link Params}.
 * If the arguments cannot be parsed (unknown option, missing required option or
 * argument) or a numeric option carries a malformed value, the reason is written
 * to standard error, the usage text is printed and the JVM exits.
 *
 * @param args the raw arguments handed to {@code main}
 * @return the populated parameters
 */
private static Params parse(String[] args) {
  Options options = generateCommandlineOptions();
  CommandLineParser parser = new PosixParser();
  Params params = new Params();

  try {
    CommandLine cmd = parser.parse(options, args);
    String value;
    if (cmd.hasOption("input")) {
      params.input = cmd.getOptionValue("input");
    }
    if (cmd.hasOption("maxIter")) {
      value = cmd.getOptionValue("maxIter");
      params.maxIter = Integer.parseInt(value);
    }
    if (cmd.hasOption("tol")) {
      value = cmd.getOptionValue("tol");
      params.tol = Double.parseDouble(value);
    }
    if (cmd.hasOption("fitIntercept")) {
      value = cmd.getOptionValue("fitIntercept");
      params.fitIntercept = Boolean.parseBoolean(value);
    }
    if (cmd.hasOption("regParam")) {
      value = cmd.getOptionValue("regParam");
      params.regParam = Double.parseDouble(value);
    }
    if (cmd.hasOption("elasticNetParam")) {
      value = cmd.getOptionValue("elasticNetParam");
      params.elasticNetParam = Double.parseDouble(value);
    }
    if (cmd.hasOption("testInput")) {
      value = cmd.getOptionValue("testInput");
      params.testInput = value;
    }
    if (cmd.hasOption("fracTest")) {
      value = cmd.getOptionValue("fracTest");
      params.fracTest = Double.parseDouble(value);
    }

  } catch (ParseException e) {
    // Tell the user what was wrong before showing the usage text.
    System.err.println(e.getMessage());
    printHelpAndQuit(options);
  } catch (NumberFormatException e) {
    // Integer.parseInt / Double.parseDouble reject malformed numbers with an unchecked
    // exception; treat that like any other bad command line instead of crashing.
    System.err.println("Invalid numeric option value: " + e.getMessage());
    printHelpAndQuit(options);
  }
  return params;
}

/**
 * Builds the command-line options understood by this example.
 *
 * Every option takes one argument; only {@code input} is required. The argument
 * name given to {@code withArgName} is the placeholder shown in the usage text,
 * so it is kept equal to the option name.
 *
 * @return the options accepted by {@code parse}
 */
@SuppressWarnings("static")
private static Options generateCommandlineOptions() {
  Option input = OptionBuilder.withArgName("input")
    .hasArg()
    .isRequired()
    .withDescription("input path to labeled examples. This path must be specified")
    .create("input");
  Option testInput = OptionBuilder.withArgName("testInput")
    .hasArg()
    .withDescription("input path to test examples")
    .create("testInput");
  // The argument name was previously "testInput" (copy-paste slip), which made the
  // usage text show the wrong placeholder for this option.
  Option fracTest = OptionBuilder.withArgName("fracTest")
    .hasArg()
    .withDescription("fraction of data to hold out for testing." +
      " If given option testInput, this option is ignored. default: 0.2")
    .create("fracTest");
  Option maxIter = OptionBuilder.withArgName("maxIter")
    .hasArg()
    .withDescription("maximum number of iterations for Logistic Regression. default:100")
    .create("maxIter");
  Option tol = OptionBuilder.withArgName("tol")
    .hasArg()
    .withDescription("the convergence tolerance of iterations " +
      "for Logistic Regression. default: 1E-6")
    .create("tol");
  Option fitIntercept = OptionBuilder.withArgName("fitIntercept")
    .hasArg()
    .withDescription("fit intercept for logistic regression. default true")
    .create("fitIntercept");
  Option regParam = OptionBuilder.withArgName("regParam")
    .hasArg()
    .withDescription("the regularization parameter for Logistic Regression.")
    .create("regParam");
  Option elasticNetParam = OptionBuilder.withArgName("elasticNetParam")
    .hasArg()
    .withDescription("the ElasticNet mixing parameter for Logistic Regression.")
    .create("elasticNetParam");

  Options options = new Options()
    .addOption(input)
    .addOption(testInput)
    .addOption(fracTest)
    .addOption(maxIter)
    .addOption(tol)
    .addOption(fitIntercept)
    .addOption(regParam)
    .addOption(elasticNetParam);

  return options;
}

/**
 * Prints the usage text for the given options and terminates the JVM with exit status -1.
 *
 * @param options the options to describe in the usage text
 */
private static void printHelpAndQuit(Options options) {
  new HelpFormatter().printHelp("JavaOneVsRestExample", options);
  System.exit(-1);
}
}
68 changes: 68 additions & 0 deletions examples/src/main/python/ml/one_vs_rest_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

# $example on$
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# $example off$
from pyspark.sql import SparkSession

"""
An example of Multiclass to Binary Reduction with One Vs Rest,
using Logistic Regression as the base classifier.
Run with:
bin/spark-submit examples/src/main/python/ml/one_vs_rest_example.py
"""


if __name__ == "__main__":
    # Obtain (or reuse) the Spark session this example runs in.
    spark = SparkSession.builder.appName("PythonOneVsRestExample").getOrCreate()

    # $example on$
    # Read the multiclass sample data, stored in libsvm format.
    dataset = spark.read.format("libsvm").load(
        "data/mllib/sample_multiclass_classification_data.txt")

    # Randomly split the rows into training and test sets with weights 0.8 / 0.2.
    training_set, test_set = dataset.randomSplit([0.8, 0.2])

    # Logistic Regression serves as the binary base classifier.
    base_classifier = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)

    # Wrap the base classifier in the One Vs Rest reduction.
    one_vs_rest = OneVsRest(classifier=base_classifier)

    # Fit the multiclass model on the training set.
    model = one_vs_rest.fit(training_set)

    # Apply the fitted model to the held-out test set.
    scored = model.transform(test_set)

    # Evaluate the predictions with the "precision" metric.
    evaluator = MulticlassClassificationEvaluator(metricName="precision")
    precision = evaluator.evaluate(scored)

    # Report the classification error (1 - precision) on the test set.
    test_error = 1 - precision
    print("Test Error : " + str(test_error))
    # $example off$

    spark.stop()
Loading