151 changes: 151 additions & 0 deletions data/mllib/iris_data.txt
@@ -0,0 +1,151 @@
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

163 changes: 120 additions & 43 deletions docs/mllib-linear-methods.md
@@ -144,41 +144,7 @@ denoted by $\x$, the model makes predictions based on the value of $\wv^T \x$.
By default, if $\wv^T \x \geq 0$ then the outcome is positive, and negative
otherwise.

### Logistic regression

[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a
binary response.
It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`, with the loss
function in the formulation given by the logistic loss:
`\[
L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)).
\]`

The logistic regression algorithm outputs a logistic regression model. Given a
new data point, denoted by $\x$, the model makes predictions by
applying the logistic function
`\[
\mathrm{f}(z) = \frac{1}{1 + e^{-z}}
\]`
where $z = \wv^T \x$.
By default, if $\mathrm{f}(\wv^T \x) > 0.5$, the outcome is positive, and
negative otherwise. Unlike linear SVMs, however, the raw output of the logistic regression
model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability
that $\x$ is positive).

### Evaluation metrics

MLlib supports common evaluation metrics for binary classification (not available in PySpark).
These include precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score),
[receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic),
the precision-recall curve, and
[area under the curve (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve).
AUC is commonly used to compare the performance of various models, while
precision/recall/F-measure can help determine the appropriate threshold to use
for prediction purposes.

### Examples
#### Examples

<div class="codetabs">

@@ -211,7 +177,7 @@ val model = SVMWithSGD.train(training, numIterations)
// Clear the default threshold.
model.clearThreshold()

// Compute raw scores on the test set.
val scoreAndLabels = test.map { point =>
val score = model.predict(point.features)
(score, point.label)
@@ -283,11 +249,11 @@ public class SVMClassifier {
JavaRDD<LabeledPoint> training = data.sample(false, 0.6, 11L);
training.cache();
JavaRDD<LabeledPoint> test = data.subtract(training);

// Run training algorithm to build the model.
int numIterations = 100;
final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations);

// Clear the default threshold.
model.clearThreshold();

@@ -300,12 +266,12 @@ }
}
}
);

// Get evaluation metrics.
BinaryClassificationMetrics metrics =
new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
double auROC = metrics.areaUnderROC();

System.out.println("Area under ROC = " + auROC);

model.save("myModelPath");
@@ -339,11 +305,95 @@ Applications](quick-start.html#self-contained-applications) section of the Spark
quick-start guide. Be sure to also include *spark-mllib* in your build file as
a dependency.
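
For an sbt build, for instance, that dependency might look like the following (the version
string is only a placeholder; use the version matching your Spark release):

{% highlight scala %}
// In build.sbt; adjust the version to match your Spark installation.
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "1.3.0"
{% endhighlight %}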
</div>
</div>

<div data-lang="python" markdown="1">
The following example shows how to load a sample dataset, build Logistic Regression model,
### Logistic regression

[Logistic regression](http://en.wikipedia.org/wiki/Logistic_regression) is widely used to predict a
binary response. It is a linear method as described above in equation `$\eqref{eq:regPrimal}$`,
with the loss function in the formulation given by the logistic loss:
`\[
L(\wv;\x,y) := \log(1+\exp( -y \wv^T \x)).
\]`
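
For intuition: with labels encoded as $y \in \{-1, +1\}$, the logistic loss is exactly the
negative log-likelihood of the logistic model, since
`\[
P(y \mid \x; \wv) = \frac{1}{1+\exp(-y \wv^T \x)}
\quad\Longrightarrow\quad
-\log P(y \mid \x; \wv) = \log(1+\exp(-y \wv^T \x)).
\]`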

Binary logistic regression can be generalized into multinomial logistic regression to
train and predict multi-class classification problems. For example, for $K$ possible outcomes,
one of the outcomes can be chosen as a "pivot", and the other $K - 1$ outcomes can be separately
regressed against the pivot outcome. In MLlib, the first class, $0$, is chosen as the "pivot" class.
See Equations (4.17) and (4.18) on page 119 of
[The Elements of Statistical Learning: Data Mining, Inference, and Prediction, 2nd Edition](http://statweb.stanford.edu/~tibs/ElemStatLearn/printings/ESLII_print10.pdf)
by Trevor Hastie, Robert Tibshirani, and Jerome Friedman, and
[Multinomial logistic regression](http://en.wikipedia.org/wiki/Multinomial_logistic_regression)
for references. A [detailed mathematical derivation](http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297)
is also available.
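
Concretely, with class $0$ as the pivot and $\wv_k$ denoting the weight vector regressed for
class $k$, the pivoting scheme above yields class probabilities of the standard form:
`\[
P(y = k \mid \x) = \frac{\exp(\wv_k^T \x)}{1 + \sum_{j=1}^{K-1} \exp(\wv_j^T \x)}
\quad (k = 1, \ldots, K-1),
\qquad
P(y = 0 \mid \x) = \frac{1}{1 + \sum_{j=1}^{K-1} \exp(\wv_j^T \x)}.
\]`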

For binary classification problems, the algorithm outputs a binary logistic regression model.
Given a new data point, denoted by $\x$, the model makes predictions by
applying the logistic function
`\[
\mathrm{f}(z) = \frac{1}{1 + e^{-z}}
\]`
where $z = \wv^T \x$.
By default, if $\mathrm{f}(\wv^T \x) > 0.5$, the outcome is positive, and
negative otherwise. Unlike linear SVMs, however, the raw output of the logistic regression
model, $\mathrm{f}(z)$, has a probabilistic interpretation (i.e., the probability
that $\x$ is positive).
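
As a minimal illustration of this decision rule (plain Scala, independent of any MLlib API):

{% highlight scala %}
// Logistic function: maps the raw margin z = w^T x to a probability in (0, 1).
def logistic(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

// Default decision rule: predict positive when f(z) > 0.5,
// which is equivalent to z > 0.
def predictBinary(z: Double): Double = if (logistic(z) > 0.5) 1.0 else 0.0
{% endhighlight %}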

For multi-class classification problems, the algorithm outputs $K - 1$ binary
logistic regression models regressed against the first class, $0$, as the "pivot" outcome.
Given a new data point, the $K - 1$ models are run, the resulting probabilities are
normalized to sum to $1.0$, and the class with the largest probability is chosen as the output.
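
The following sketch shows how such a prediction could be assembled from the $K - 1$
per-class margins (a hypothetical helper for illustration only, not an MLlib API):

{% highlight scala %}
// margins(k) holds the margin w_{k+1}^T x for classes 1 .. K-1; class 0 is the pivot.
def predictMulticlass(margins: Array[Double]): Int = {
  val expMargins = margins.map(math.exp)
  val denom = 1.0 + expMargins.sum
  // Normalized probabilities P(y = 0), P(y = 1), ..., P(y = K-1), summing to 1.0.
  val probs = (1.0 / denom) +: expMargins.map(_ / denom)
  // Output the class with the largest probability.
  probs.indexOf(probs.max)
}
{% endhighlight %}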

### Examples

The following example shows how to load a sample dataset, build a binary logistic regression model,
and make predictions with the resulting model to compute the training error.

<div class="codetabs">
<div data-lang="scala" markdown="1">
{% highlight scala %}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils

// Load training data in LIBSVM format.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

// Split data into training (60%) and test (40%).
val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0).cache()
val test = splits(1)

// Run training algorithm to build the model. L-BFGS is used here rather than
// SGD; see the optimization notes at the end of this page.
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(training)

// Clear the default threshold so that predict returns raw scores.
model.clearThreshold()

// Compute raw scores on the test set.
val scoreAndLabels = test.map { point =>
  val score = model.predict(point.features)
  (score, point.label)
}

// Get evaluation metrics.
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val auROC = metrics.areaUnderROC()

println("Area under ROC = " + auROC)

model.save("myModelPath")
val sameModel = LogisticRegressionModel.load("myModelPath")
{% endhighlight %}
</div>

<div data-lang="python" markdown="1">

Note that the Python API does not yet support model save/load but will in the future.

{% highlight python %}
@@ -370,6 +420,23 @@ print("Training Error = " + str(trainErr))
</div>
</div>

The following example shows how to load the iris dataset, which has three classes, build a
multinomial logistic regression model, and make predictions with the resulting model to
compute the training error.
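
A sketch of what such an example could look like in Scala (this code is not part of the diff;
it parses the CSV iris data above and assumes the `setNumClasses` method on
`LogisticRegressionWithLBFGS` described in the optimization notes below):

{% highlight scala %}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Map the three iris species names to numeric class labels 0.0, 1.0, 2.0.
val classes = Map("Iris-setosa" -> 0.0, "Iris-versicolor" -> 1.0, "Iris-virginica" -> 2.0)

// Parse the CSV data: four numeric features followed by the species name.
val parsedData = sc.textFile("data/mllib/iris_data.txt").map { line =>
  val parts = line.split(',')
  LabeledPoint(classes(parts(4)), Vectors.dense(parts.slice(0, 4).map(_.toDouble)))
}.cache()

// Train a multinomial logistic regression model with three classes.
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(3)
  .run(parsedData)

// Compute the training error.
val labelAndPreds = parsedData.map { point =>
  (point.label, model.predict(point.features))
}
val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.count
println("Training Error = " + trainErr)
{% endhighlight %}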


### Evaluation metrics

MLlib supports common evaluation metrics for binary classification (not available in PySpark).
These include precision, recall, [F-measure](http://en.wikipedia.org/wiki/F1_score),
[receiver operating characteristic (ROC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic),
the precision-recall curve, and
[area under the curve (AUC)](http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve).
AUC is commonly used to compare the performance of various models, while
precision/recall/F-measure can help determine the appropriate threshold to use
for prediction purposes.
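
For example, the threshold-based metrics and the curve areas can be read off a
`BinaryClassificationMetrics` instance built from the `scoreAndLabels` RDD of
(score, label) pairs computed in the examples above (a sketch reusing those names):

{% highlight scala %}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

val metrics = new BinaryClassificationMetrics(scoreAndLabels)
// Per-threshold metrics, each an RDD of (threshold, value) pairs.
val precision = metrics.precisionByThreshold()
val recall = metrics.recallByThreshold()
val f1 = metrics.fMeasureByThreshold()
// Scalar summaries, useful for comparing models.
val auROC = metrics.areaUnderROC()
val auPR = metrics.areaUnderPR()
{% endhighlight %}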

## Linear least squares, Lasso, and ridge regression


@@ -624,9 +691,19 @@ regularization parameter (`regParam`) along with various parameters associated with
gradient descent (`stepSize`, `numIterations`, `miniBatchFraction`). For each of them, we support
all three possible regularizations (none, L1 or L2).

For logistic regression, an [L-BFGS](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS)
version is implemented as
[LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS).
This version supports both binary and multinomial logistic regression, while the SGD version
supports only binary logistic regression. However, the L-BFGS version does not support L1
regularization, whereas the SGD version does. When L1 regularization is not required, the
L-BFGS version is strongly recommended, since it converges faster and more accurately than SGD
by approximating the inverse Hessian matrix using a quasi-Newton method.
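
When L1 regularization is required, one can instead fall back to the SGD version and swap in
the L1 updater, mirroring the optional L1 configuration shown for linear SVMs in this guide
(a sketch; `training` is an `RDD[LabeledPoint]` as in the examples above):

{% highlight scala %}
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD
import org.apache.spark.mllib.optimization.L1Updater

val lrAlg = new LogisticRegressionWithSGD()
lrAlg.optimizer
  .setNumIterations(200)
  .setRegParam(0.1)
  .setUpdater(new L1Updater)  // L1 regularization via the SGD optimizer
val modelL1 = lrAlg.run(training)
{% endhighlight %}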

Algorithms are all implemented in Scala:

* [SVMWithSGD](api/scala/index.html#org.apache.spark.mllib.classification.SVMWithSGD)
* [LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS)
* [LogisticRegressionWithSGD](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithSGD)
* [LinearRegressionWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.LinearRegressionWithSGD)
* [RidgeRegressionWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD)