Commit 47af0ac

update user guide for multinomial logistic regression
1 parent cdc2e15 commit 47af0ac

docs/mllib-linear-methods.md

Lines changed: 185 additions & 48 deletions
@@ -17,7 +17,7 @@ displayTitle: <a href="mllib-guide.html">MLlib</a> - Linear Methods
 \newcommand{\av}{\mathbf{\alpha}}
 \newcommand{\bv}{\mathbf{b}}
 \newcommand{\N}{\mathbb{N}}
-\newcommand{\id}{\mathbf{I}}
+\newcommand{\id}{\mathbf{I}}
 \newcommand{\ind}{\mathbf{1}}
 \newcommand{\0}{\mathbf{0}}
 \newcommand{\unit}{\mathbf{e}}
@@ -114,18 +114,26 @@ especially when the number of training examples is small.
 
 Under the hood, linear methods use convex optimization methods to optimize the objective functions. MLlib uses two methods, SGD and L-BFGS, described in the [optimization section](mllib-optimization.html). Currently, most algorithm APIs support Stochastic Gradient Descent (SGD), and a few support L-BFGS. Refer to [this optimization section](mllib-optimization.html#Choosing-an-Optimization-Method) for guidelines on choosing between optimization methods.
 
-## Binary classification
-
-[Binary classification](http://en.wikipedia.org/wiki/Binary_classification)
-aims to divide items into two categories: positive and negative. MLlib
-supports two linear methods for binary classification: linear Support Vector
-Machines (SVMs) and logistic regression. For both methods, MLlib supports
-L1 and L2 regularized variants. The training data set is represented by an RDD
-of [LabeledPoint](mllib-data-types.html) in MLlib. Note that, in the
-mathematical formulation in this guide, a training label $y$ is denoted as
-either $+1$ (positive) or $-1$ (negative), which is convenient for the
-formulation. *However*, the negative label is represented by $0$ in MLlib
-instead of $-1$, to be consistent with multiclass labeling.
+## Classification
+
+[Classification](http://en.wikipedia.org/wiki/Statistical_classification) aims to divide items into
+categories.
+The most common classification type is
+[binary classification](http://en.wikipedia.org/wiki/Binary_classification), where there are two
+categories, usually named positive and negative.
+If there are more than two categories, it is called
+[multiclass classification](http://en.wikipedia.org/wiki/Multiclass_classification).
+MLlib supports two linear methods for classification: linear Support Vector Machines (SVMs)
+and logistic regression.
+Linear SVMs support only binary classification, while logistic regression supports both binary and
+multiclass classification problems.
+For both methods, MLlib supports L1 and L2 regularized variants.
+The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib,
+where labels are class indices starting from zero: $0, 1, 2, \ldots$.
+Note that, in the mathematical formulation in this guide, a binary label $y$ is denoted as either
+$+1$ (positive) or $-1$ (negative), which is convenient for the formulation.
+*However*, the negative label is represented by $0$ in MLlib instead of $-1$, to be consistent with
+multiclass labeling.
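
For illustration, a tiny multiclass training set with zero-based labels could be constructed as follows (a minimal sketch; the feature values here are made up):

{% highlight scala %}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Labels are class indices starting from zero: 0.0, 1.0, 2.0, ...
val points = Seq(
  LabeledPoint(0.0, Vectors.dense(1.0, 0.1)),
  LabeledPoint(1.0, Vectors.dense(0.2, 1.3)),
  LabeledPoint(2.0, Vectors.dense(0.9, 0.8)))
val training = sc.parallelize(points)  // RDD[LabeledPoint]
{% endhighlight %}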
 
 ### Linear Support Vector Machines (SVMs)
 
@@ -144,7 +152,7 @@ denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$.
 By default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative
 otherwise.
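
As a sketch of this decision rule, assuming `model` is a trained `SVMModel` and `features` is a `Vector`, the threshold can also be cleared so that `predict` returns the raw score $\wv^T \x$ instead of a 0/1 label:

{% highlight scala %}
// Assumes `model` is a trained SVMModel and `features` a Vector.
model.setThreshold(0.0)   // default: predict 1.0 when w^T x >= 0
val label = model.predict(features)

model.clearThreshold()    // now predict() returns the raw score w^T x
val margin = model.predict(features)
{% endhighlight %}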
 
-#### Examples
+**Examples**
 
 <div class="codetabs">
 
@@ -213,8 +221,6 @@ svmAlg.optimizer.
 val modelL1 = svmAlg.run(training)
 {% endhighlight %}
 
-[`LogisticRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD) can be used in a similar fashion as `SVMWithSGD`.
-
 </div>
 
 <div data-lang="java" markdown="1">
@@ -347,18 +353,6 @@ with the loss function in the formulation given by the logistic loss:
 L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)).
 \]`
 
-Binary logistic regression can be generalized into multinomial logistic regression to
-train and predict multi-class classification problems. For example, for $K$ possible outcomes,
-one of the outcomes can be chosen as a "pivot", and the other $K - 1$ outcomes can be separately
-regressed against the pivot outcome. In mllib, the first class, $0$ is chosen as "pivot" class.
-See $Eq.~(4.17)$ and $Eq.~(4.18)$ on page 119 of
-[The Elements of Statistical Learning: Data Mining, Inference, and Prediction, 2nd Edition](http://statweb.stanford.edu/~tibs/ElemStatLearn/printings/ESLII_print10.pdf) by
-Trevor Hastie, Robert Tibshirani, and Jerome Friedman, and
-[Multinomial logistic regression](http://en.wikipedia.org/wiki/Multinomial_logistic_regression)
-for references. Here is [the detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297).
-
 For binary classification problems, the algorithm outputs a binary logistic regression model.
 Given a new data point, denoted by $\x$, the model makes predictions by
 applying the logistic function
@@ -371,27 +365,170 @@ negative otherwise, though unlike linear SVMs, the raw output of the logistic re
 model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability
 that $\x$ is positive).
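
As a toy illustration of this rule (plain Scala, not an MLlib API): compute $\mathrm{f}(\wv^T \x)$ with $\mathrm{f}(z) = 1/(1 + e^{-z})$ and predict positive when the probability exceeds $0.5$:

{% highlight scala %}
// Toy illustration of binary logistic regression prediction.
def logistic(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

val w = Array(0.5, -1.2)           // learned weights (made up)
val x = Array(2.0, 0.3)            // new data point (made up)
val z = w.zip(x).map { case (wi, xi) => wi * xi }.sum
val probPositive = logistic(z)     // probability that x is positive
val prediction = if (probPositive > 0.5) 1.0 else 0.0
{% endhighlight %}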
 
-For multi-class classification problems, the algorithm will outputs $K - 1$ binary
-logistic regression models regressed against the first class, $0$ as "pivot" outcome.
-Given a new data points, $K - 1$ models will be run, and the probabilities will be
-normalized into $1.0$. The class with largest probability will be chosen as output.
+Binary logistic regression can be generalized into
+[multinomial logistic regression](http://en.wikipedia.org/wiki/Multinomial_logistic_regression) to
+train and predict multiclass classification problems.
+For example, for $K$ possible outcomes, one of the outcomes can be chosen as a "pivot", and the
+other $K - 1$ outcomes can be separately regressed against the pivot outcome.
+In MLlib, the first class $0$ is chosen as the "pivot" class.
+See Section 4.4 of
+[The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
+references.
+Here is a
+[detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297).
+
+For multiclass classification problems, the algorithm outputs a multinomial logistic regression
+model, which contains $K - 1$ binary logistic regression models regressed against the first class.
+Given a new data point, $K - 1$ models are run, and the class with the largest probability is
+chosen as the predicted class.
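
To make the pivot construction concrete, the following toy sketch (illustrative only, not MLlib internals) computes class probabilities from hypothetical per-class weight vectors $\wv_1, \ldots, \wv_{K-1}$, with class $0$ as the pivot:

{% highlight scala %}
// Toy illustration: class probabilities under the pivot formulation.
def dot(w: Array[Double], x: Array[Double]): Double =
  w.zip(x).map { case (wi, xi) => wi * xi }.sum

def classProbabilities(ws: Array[Array[Double]], x: Array[Double]): Array[Double] = {
  val expMargins = ws.map(w => math.exp(dot(w, x)))  // K - 1 values
  val denom = 1.0 + expMargins.sum
  // The pivot class 0 gets 1 / denom; class k gets exp(w_k^T x) / denom.
  (1.0 +: expMargins).map(_ / denom)
}

// The predicted class is the index with the largest probability.
def predictClass(ws: Array[Array[Double]], x: Array[Double]): Int = {
  val probs = classProbabilities(ws, x)
  probs.indices.maxBy(i => probs(i))
}
{% endhighlight %}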
+
+We implemented two algorithms to solve logistic regression: mini-batch gradient descent and L-BFGS.
+We recommend L-BFGS over mini-batch gradient descent for faster convergence.
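
Concretely, the two solvers are exposed as separate classes. A minimal sketch of choosing between them, assuming `training` is an `RDD[LabeledPoint]`:

{% highlight scala %}
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionWithSGD}

// L-BFGS (recommended): supports multiclass via setNumClasses.
val lbfgsModel = new LogisticRegressionWithLBFGS()
  .setNumClasses(10)
  .run(training)

// Mini-batch gradient descent: binary only, typically slower to converge.
val sgdModel = LogisticRegressionWithSGD.train(training, 100)  // 100 iterations
{% endhighlight %}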
387+
388+
**Examples**
389+
390+
<div class="codetabs">
391+
392+
<div data-lang="scala" markdown="1">
393+
The following code illustrates how to load a sample multiclass dataset, split it into train and
394+
test, and use
395+
[LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS)
396+
to fit a logistic regression model.
397+
Then the model is evaluated against the test dataset and saved to disk.
398+
399+
{% highlight scala %}
400+
import org.apache.spark.SparkContext
401+
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
402+
import org.apache.spark.mllib.evaluation.MulticlassMetrics
403+
import org.apache.spark.mllib.regression.LabeledPoint
404+
import org.apache.spark.mllib.linalg.Vectors
405+
import org.apache.spark.mllib.util.MLUtils
406+
407+
// Load training data in LIBSVM format.
408+
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
409+
410+
// Split data into training (60%) and test (40%).
411+
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
412+
val training = splits(0).cache()
413+
val test = splits(1)
414+
415+
// Run training algorithm to build the model
416+
val model = new LogisticRegressionWithLBFGS()
417+
.setNumClasses(10)
418+
.run(training)
419+
420+
// Compute raw scores on the test set.
421+
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
422+
val prediction = model.predict(features)
423+
(prediction, label)
424+
}
425+
426+
// Get evaluation metrics.
427+
val metrics = new MulticlassMetrics(predictionAndLabels)
428+
val precision = metrics.precision
429+
println("Precision = " + precision)
430+
431+
// Save and load model
432+
model.save(sc, "myModelPath")
433+
val sameModel = LogisticRegressionModel.load(sc, "myModelPath")
434+
{% endhighlight %}
435+
436+
</div>
+
+<div data-lang="java" markdown="1">
+The following code illustrates how to load a sample multiclass dataset, split it into train and
+test, and use
+[LogisticRegressionWithLBFGS](api/java/org/apache/spark/mllib/classification/LogisticRegressionWithLBFGS.html)
+to fit a logistic regression model.
+Then the model is evaluated against the test dataset and saved to disk.
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.MulticlassMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class MultinomialLogisticRegressionExample {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Multinomial Logistic Regression Example");
+    SparkContext sc = new SparkContext(conf);
+    String path = "data/mllib/sample_libsvm_data.txt";
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
 
-#### Examples
+    // Split initial RDD into two... [60% training data, 40% testing data].
+    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
+    JavaRDD<LabeledPoint> training = splits[0].cache();
+    JavaRDD<LabeledPoint> test = splits[1];
 
+    // Run training algorithm to build the model.
+    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+      .setNumClasses(10)
+      .run(training.rdd());
 
-### Evaluation metrics
+    // Compute raw scores on the test set.
+    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint p) {
+          Double prediction = model.predict(p.features());
+          return new Tuple2<Object, Object>(prediction, p.label());
+        }
+      }
+    );
+
+    // Get evaluation metrics.
+    MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
+    double precision = metrics.precision();
+    System.out.println("Precision = " + precision);
+
+    // Save and load model.
+    model.save(sc, "myModelPath");
+    LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
+  }
+}
+{% endhighlight %}
+</div>
+
+<div data-lang="python" markdown="1">
+The following example shows how to load a sample dataset, build a logistic regression model,
+and make predictions with the resulting model to compute the training error.
+
+Note that the Python API does not yet support multiclass classification and model save/load but
+will in the future.
+
+{% highlight python %}
+from pyspark.mllib.classification import LogisticRegressionWithLBFGS
+from pyspark.mllib.regression import LabeledPoint
+
+# Load and parse the data
+def parsePoint(line):
+    values = [float(x) for x in line.split(' ')]
+    return LabeledPoint(values[0], values[1:])
+
+data = sc.textFile("data/mllib/sample_svm_data.txt")
+parsedData = data.map(parsePoint)
+
+# Build the model
+model = LogisticRegressionWithLBFGS.train(parsedData)
+
+# Evaluate the model on training data
+labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features)))
+trainErr = labelsAndPreds.filter(lambda vp: vp[0] != vp[1]).count() / float(parsedData.count())
+print("Training Error = " + str(trainErr))
+{% endhighlight %}
+</div>
+</div>
 
-MLlib supports common evaluation metrics for binary classification (not available in PySpark).
-This
-includes precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score),
-[receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic),
-precision-recall curve, and
-[area under the curves (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve).
-AUC is commonly used to compare the performance of various models while
-precision/recall/F-measure can help determine the appropriate threshold to use
-for prediction purposes.
+## Regression
 
-## Linear least squares, Lasso, and ridge regression
+### Linear least squares, Lasso, and ridge regression
 
 
 Linear least squares is the most common formulation for regression problems.
@@ -409,7 +546,7 @@ regularization; and [*Lasso*](http://en.wikipedia.org/wiki/Lasso_(statistics)) uses L1
 regularization. For all of these models, the average loss or training error, $\frac{1}{n} \sum_{i=1}^n (\wv^T x_i - y_i)^2$, is
 known as the [mean squared error](http://en.wikipedia.org/wiki/Mean_squared_error).
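
As a one-line sketch of this metric, assuming `valuesAndPreds` is an RDD of `(label, prediction)` pairs:

{% highlight scala %}
// Mean squared error: average of (w^T x_i - y_i)^2 over the data set.
val MSE = valuesAndPreds.map { case (v, p) => math.pow(v - p, 2) }.mean()
{% endhighlight %}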
 
-### Examples
+**Examples**
 
 <div class="codetabs">
 
@@ -563,15 +700,15 @@ section of the Spark
 quick-start guide. Be sure to also include *spark-mllib* to your build file as
 a dependency.
 
-## Streaming linear regression
+### Streaming linear regression
 
 When data arrive in a streaming fashion, it is useful to fit regression models online,
 updating the parameters of the model as new data arrives. MLlib currently supports
 streaming linear regression using ordinary least squares. The fitting is similar
 to that performed offline, except fitting occurs on each batch of data, so that
 the model continually updates to reflect the data from the stream.
 
-### Examples
+**Examples**
 
 The following example demonstrates how to load training and testing data from two different
 input streams of text files, parse the streams as labeled points, fit a linear regression model
@@ -638,7 +775,7 @@ will get better!
 </div>
 
 
-## Implementation (developer)
+## Implementation (developer)
 
 Behind the scenes, MLlib implements a simple distributed version of stochastic gradient descent
 (SGD), building on the underlying gradient descent primitive (as described in the <a
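
As a rough sketch of that primitive (illustrative only; `runSGD` is a hypothetical name, and the actual implementation lives in `org.apache.spark.mllib.optimization.GradientDescent`): each iteration samples a mini-batch, the workers compute subgradients locally, and their average drives a step with size decaying as $1/\sqrt{i}$:

{% highlight scala %}
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

// Illustrative sketch of distributed mini-batch SGD for least squares,
// mirroring the shape of MLlib's GradientDescent, not its actual code.
def runSGD(data: RDD[LabeledPoint], numIterations: Int, stepSize: Double,
    miniBatchFraction: Double, dim: Int): Array[Double] = {
  var w = Array.fill(dim)(0.0)
  for (i <- 1 to numIterations) {
    val wCur = w  // snapshot of the current weights, shipped to the workers
    val batch = data.sample(false, miniBatchFraction, 42 + i)
    val (gradSum, count) = batch.map { p =>
      // Least-squares subgradient at p: (w^T x - y) x, computed on a worker.
      val err = wCur.zip(p.features.toArray).map { case (wi, xi) => wi * xi }.sum - p.label
      (p.features.toArray.map(_ * err), 1L)
    }.reduce { case ((g1, c1), (g2, c2)) =>
      (g1.zip(g2).map { case (a, b) => a + b }, c1 + c2)
    }
    // Average over the mini-batch and decay the step size as 1 / sqrt(i).
    val step = stepSize / math.sqrt(i)
    w = w.zip(gradSum).map { case (wi, gi) => wi - step * gi / count }
  }
  w
}
{% endhighlight %}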
