Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.storage.StorageLevel

/**
Expand Down Expand Up @@ -252,7 +254,13 @@ class LogisticRegression(override val uid: String)

if (handlePersistence) instances.unpersist()

copyValues(new LogisticRegressionModel(uid, weights, intercept))
val model = copyValues(new LogisticRegressionModel(uid, weights, intercept))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please explain why this copyValues is necessary? and I'm unable to understand how $(probabilityCol) gives a string because when I do this.

val model = lr.fit(dataset)
$(lr.probabilityCol)

I get

error: not found: value $
$(probabilityCol)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

$ is defined in Params, which LogisticRegression mixes in via LogisticRegressionParams. See https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/param/params.scala#L463

Without copyValues, the model you return will not contain any non-default user-specified params (e.g. predictionCol).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see thanks !

val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
model.transform(dataset),
$(probabilityCol),
$(labelCol),
objectiveHistory)
model.setSummary(logRegSummary)
}

override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
Expand Down Expand Up @@ -286,6 +294,38 @@ class LogisticRegressionModel private[ml] (

override val numClasses: Int = 2

private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None

/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
def summary: LogisticRegressionTrainingSummary = trainingSummary match {
case Some(summ) => summ
case None =>
throw new SparkException(
"No training summary available for this LogisticRegressionModel",
new NullPointerException())
}

private[classification] def setSummary(
summary: LogisticRegressionTrainingSummary): this.type = {
this.trainingSummary = Some(summary)
this
}

/** Indicates whether a training summary exists for this model instance. */
def hasSummary: Boolean = trainingSummary.isDefined

/**
* Evaluates the model on a testset.
* @param dataset Test dataset to evaluate model on.
*/
// TODO: decide on a good name before exposing to public API
private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
}

/**
* Predict label for the given feature vector.
* The behavior of this can be adjusted using [[threshold]].
Expand Down Expand Up @@ -407,6 +447,128 @@ private[classification] class MultiClassSummarizer extends Serializable {
}
}

/**
* Abstraction for multinomial Logistic Regression Training results.
*/
sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {

/** objective function (scaled loss + regularization) at each iteration. */
def objectiveHistory: Array[Double]

/** Number of training iterations until termination */
def totalIterations: Int = objectiveHistory.length

}

/**
* Abstraction for Logistic Regression Results for a given model.
*/
sealed trait LogisticRegressionSummary extends Serializable {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This trait can have defs for predictions, probabilityCol, labelCol.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am surprised this works in spite of these being vals in the subclass.


/** Dataframe outputted by the model's `transform` method. */
def predictions: DataFrame

/** Field in "predictions" which gives the calibrated probability of each sample as a vector. */
def probabilityCol: String

/** Field in "predictions" which gives the the true label of each sample. */
def labelCol: String

}

/**
* :: Experimental ::
* Logistic regression training results.
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each sample as a vector.
* @param labelCol field in "predictions" which gives the true label of each sample.
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
*/
@Experimental
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaladocs

class BinaryLogisticRegressionTrainingSummary private[classification] (
predictions: DataFrame,
probabilityCol: String,
labelCol: String,
val objectiveHistory: Array[Double])
extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
with LogisticRegressionTrainingSummary {

}

/**
* :: Experimental ::
* Binary Logistic regression results for a given model.
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each sample.
* @param labelCol field in "predictions" which gives the true label of each sample.
*/
@Experimental
class BinaryLogisticRegressionSummary private[classification] (
@transient override val predictions: DataFrame,
override val probabilityCol: String,
override val labelCol: String) extends LogisticRegressionSummary {

private val sqlContext = predictions.sqlContext
import sqlContext.implicits._

/**
* Returns a BinaryClassificationMetrics object.
*/
// TODO: Allow the user to vary the number of bins using a setBins method in
// BinaryClassificationMetrics. For now the default is set to 100.
@transient private val binaryMetrics = new BinaryClassificationMetrics(
predictions.select(probabilityCol, labelCol).map {
case Row(score: Vector, label: Double) => (score(1), label)
}, 100
)

/**
* Returns the receiver operating characteristic (ROC) curve,
* which is an Dataframe having two fields (FPR, TPR)
* with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
* @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
*/
@transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")

/**
* Computes the area under the receiver operating characteristic (ROC) curve.
*/
lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()

/**
* Returns the precision-recall curve, which is an Dataframe containing
* two fields recall, precision with (0.0, 1.0) prepended to it.
*/
@transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")

/**
* Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
*/
@transient lazy val fMeasureByThreshold: DataFrame = {
binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
}

/**
* Returns a dataframe with two fields (threshold, precision) curve.
* Every possible probability obtained in transforming the dataset are used
* as thresholds used in calculating the precision.
*/
@transient lazy val precisionByThreshold: DataFrame = {
binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
}

/**
* Returns a dataframe with two fields (threshold, recall) curve.
* Every possible probability obtained in transforming the dataset are used
* as thresholds used in calculating the recall.
*/
@transient lazy val recallByThreshold: DataFrame = {
binaryMetrics.recallByThreshold().toDF("threshold", "recall")
}
}

/**
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
* in binary classification for samples in sparse or dense vector in a online fashion.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,13 @@ public void logisticRegressionPredictorClassifierMethods() {
}
}
}

@Test
public void logisticRegressionTrainingSummary() {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);

LogisticRegressionTrainingSummary summary = model.summary();
assert(summary.totalIterations() == summary.objectiveHistory().length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)

assert(model1.intercept ~== interceptR relTol 1E-5)
assert(model1.weights ~= weightsR absTol 1E-6)
assert(model1.weights ~== weightsR absTol 1E-6)
}

test("evaluate on test set") {
// Evaluate on test set should be same as that of the transformed training data.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
.setThreshold(0.6)
val model = lr.fit(dataset)
val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]

val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
assert(summary.areaUnderROC === sameSummary.areaUnderROC)
assert(summary.roc.collect() === sameSummary.roc.collect())
assert(summary.pr.collect === sameSummary.pr.collect())
assert(
summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
assert(
summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
}

test("statistics on training data") {
// Test that loss is monotonically decreasing.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
.setThreshold(0.6)
val model = lr.fit(dataset)
assert(
model.summary
.objectiveHistory
.sliding(2)
.forall(x => x(0) >= x(1)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: state what is being tested ("lossHistory is monotonically decreasing") in comment


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Add more tests, after the first pass has been done.

}
}