Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,53 +34,123 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel

/**
* Params for logistic regression.
*/
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
with HasStandardization {
with HasStandardization with HasThreshold {

/**
* Version of setThresholds() for binary classification, available for backwards
* compatibility.
* Set threshold in binary classification, in range [0, 1].
*
* Calling this with threshold p will effectively call `setThresholds(Array(1-p, p))`.
* If the estimated probability of class label 1 is > threshold, then predict 1, else 0.
* A high threshold encourages the model to predict 0 more often;
* a low threshold encourages the model to predict 1 more often.
*
* Note: Calling this with threshold p is equivalent to calling `setThresholds(Array(1-p, p))`.
* When [[setThreshold()]] is called, any user-set value for [[thresholds]] will be cleared.
* If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
* equivalent.
*
* Default is 0.5.
* @group setParam
*/
def setThreshold(value: Double): this.type = {
  // threshold and thresholds are mutually exclusive user-set params:
  // setting one clears the other so they can never disagree.
  if (isSet(thresholds)) {
    clear(thresholds)
  }
  set(threshold, value)
}

/**
* Get threshold for binary classification.
*
* If [[threshold]] is set, returns that value.
* Otherwise, if [[thresholds]] is set with length 2 (i.e., binary classification),
* this returns the equivalent threshold: {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
* Otherwise, returns the default value of [[threshold]].
*
* @group getParam
* @throws IllegalArgumentException if [[thresholds]] is set to an array of length other than 2.
*/
override def getThreshold: Double = {
  // If both params are set they must agree; fail fast otherwise.
  checkThresholdConsistency()
  if (!isSet(thresholds)) {
    // No thresholds override: return threshold (user-set or default).
    $(threshold)
  } else {
    // Derive the single binary threshold from the 2-element thresholds array.
    val binaryThresholds = $(thresholds)
    require(binaryThresholds.length == 2, "Logistic Regression getThreshold only applies to" +
      " binary classification, but thresholds has length != 2. thresholds: " +
      binaryThresholds.mkString(","))
    1.0 / (1.0 + binaryThresholds(0) / binaryThresholds(1))
  }
}

/**
* Set thresholds in multiclass (or binary) classification to adjust the probability of
* predicting each class. Array must have length equal to the number of classes, with values >= 0.
* The class with largest value p/t is predicted, where p is the original probability of that
* class and t is the class' threshold.
*
* Note: When [[setThresholds()]] is called, any user-set value for [[threshold]] will be cleared.
* If both [[threshold]] and [[thresholds]] are set in a ParamMap, then they must be
* equivalent.
*
* Default is effectively 0.5.
* @group setParam
*/
def setThreshold(value: Double): this.type = set(thresholds, Array(1.0 - value, value))
def setThresholds(value: Array[Double]): this.type = {
  // Mirror image of setThreshold: the two params are mutually exclusive,
  // so any user-set binary threshold is dropped when thresholds is set.
  if (isSet(threshold)) {
    clear(threshold)
  }
  set(thresholds, value)
}

/**
* Version of [[getThresholds()]] for binary classification, available for backwards
* compatibility.
* Get thresholds for binary or multiclass classification.
*
* If [[thresholds]] is set, return its value.
* Otherwise, if [[threshold]] is set, return the equivalent thresholds for binary
* classification: (1-threshold, threshold).
* If neither are set, throw an exception.
*
* Param thresholds must have length 2 (or not be specified).
* This returns {{{1 / (1 + thresholds(0) / thresholds(1))}}}.
* @group getParam
*/
def getThreshold: Double = {
if (isDefined(thresholds)) {
val thresholdValues = $(thresholds)
assert(thresholdValues.length == 2, "Logistic Regression getThreshold only applies to" +
" binary classification, but thresholds has length != 2." +
s" thresholds: ${thresholdValues.mkString(",")}")
1.0 / (1.0 + thresholdValues(0) / thresholdValues(1))
/*
 * Fix: the scraped diff left a residual removed-code line (`0.5`) inside the
 * else branch, which is not valid Scala. This is the clean new-side
 * implementation of the override.
 */
override def getThresholds: Array[Double] = {
  // If both params are set they must agree; fail fast otherwise.
  checkThresholdConsistency()
  if (!isSet(thresholds) && isSet(threshold)) {
    // Only the binary threshold is set: expose it as the equivalent
    // two-class array (1 - t, t).
    val t = $(threshold)
    Array(1.0 - t, t)
  } else {
    // thresholds is set, or neither is set (in which case $(thresholds)
    // surfaces the missing-param error, since thresholds has no default).
    $(thresholds)
  }
}

/**
* If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
* @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
*/
/**
 * Verifies that [[threshold]] and [[thresholds]], when both user-set, describe
 * the same binary decision rule; otherwise throws IllegalArgumentException.
 */
protected def checkThresholdConsistency(): Unit = {
  // Nothing to check unless the user explicitly set both params.
  if (isSet(threshold) && isSet(thresholds)) {
    val thresholdsValue = $(thresholds)
    require(thresholdsValue.length == 2,
      "Logistic Regression found inconsistent values for threshold and" +
        s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" +
        s" classification, but Param thresholds is set with length ${thresholdsValue.length}." +
        " Clear one Param value to fix this problem.")
    // Collapse thresholds to its single-threshold equivalent and compare,
    // tolerating tiny floating-point drift.
    val equivalent = 1.0 / (1.0 + thresholdsValue(0) / thresholdsValue(1))
    require(math.abs($(threshold) - equivalent) < 1E-5,
      "Logistic Regression getThreshold found" +
        s" inconsistent values for threshold (${$(threshold)}) and thresholds" +
        s" (equivalent to $equivalent)")
  }
}

// The threshold/thresholds agreement is the only cross-param invariant
// this trait adds, so param validation delegates to that single check.
override def validateParams(): Unit = {
checkThresholdConsistency()
}
}

/**
* :: Experimental ::
* Logistic regression.
* Currently, this class only supports binary classification.
* Currently, this class only supports binary classification. It will support multiclass
* in the future.
*/
@Experimental
class LogisticRegression(override val uid: String)
Expand Down Expand Up @@ -128,7 +198,7 @@ class LogisticRegression(override val uid: String)
* Whether to fit an intercept term.
* Default is true.
* @group setParam
* */
*/
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)

Expand All @@ -140,14 +210,18 @@ class LogisticRegression(override val uid: String)
* is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
* @group setParam
* */
*/
def setStandardization(value: Boolean): this.type = set(standardization, value)
setDefault(standardization -> true)

// NOTE(review): these one-line overrides delegate straight to the trait
// implementations; presumably they exist to re-expose the threshold API on
// the concrete estimator (e.g. for Java callers / the `this.type` chaining) —
// confirm against LogisticRegressionParams.
override def setThreshold(value: Double): this.type = super.setThreshold(value)

override def getThreshold: Double = super.getThreshold

override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)

override def getThresholds: Array[Double] = super.getThresholds

override protected def train(dataset: DataFrame): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractLabeledPoints(dataset).map {
Expand Down Expand Up @@ -314,6 +388,10 @@ class LogisticRegressionModel private[ml] (

// NOTE(review): pure delegating overrides mirroring the estimator's; they add
// no behavior beyond what the shared trait provides — presumably kept so the
// model class surfaces the same threshold API directly. Confirm intent.
override def getThreshold: Double = super.getThreshold

override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)

override def getThresholds: Array[Double] = super.getThresholds

/** Margin (rawPrediction) for class label 1. For binary classification only. */
private val margin: Vector => Double = (features) => {
BLAS.dot(features, weights) + intercept
Expand Down Expand Up @@ -364,6 +442,7 @@ class LogisticRegressionModel private[ml] (
* The behavior of this can be adjusted using [[thresholds]].
*/
override protected def predict(features: Vector): Double = {
  // Must go through getThreshold rather than $(threshold): getThreshold is
  // overridden to reconcile the threshold and thresholds params.
  val predictClassOne = score(features) > getThreshold
  if (predictClassOne) 1.0 else 0.0
}

Expand Down Expand Up @@ -393,6 +472,7 @@ class LogisticRegressionModel private[ml] (
}

override protected def raw2prediction(rawPrediction: Vector): Double = {
// Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
val t = getThreshold
val rawThreshold = if (t == 0.0) {
Double.NegativeInfinity
Expand All @@ -405,6 +485,7 @@ class LogisticRegressionModel private[ml] (
}

override protected def probability2prediction(probability: Vector): Double = {
  // Must go through getThreshold rather than $(threshold): getThreshold is
  // overridden to reconcile the threshold and thresholds params.
  val predictClassOne = probability(1) > getThreshold
  if (predictClassOne) 1.0 else 0.0
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ private[shared] object SharedParamsCodeGen {
" These probabilities should be treated as confidences, not precise probabilities.",
Some("\"probability\"")),
ParamDesc[Double]("threshold",
"threshold in binary classification prediction, in range [0, 1]",
"threshold in binary classification prediction, in range [0, 1]", Some("0.5"),
isValid = "ParamValidators.inRange(0, 1)", finalMethods = false),
ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" +
" to adjust the probability of predicting each class." +
" Array must have length equal to the number of classes, with values >= 0." +
" The class with largest value p/t is predicted, where p is the original probability" +
" of that class and t is the class' threshold.",
isValid = "(t: Array[Double]) => t.forall(_ >= 0)"),
isValid = "(t: Array[Double]) => t.forall(_ >= 0)", finalMethods = false),
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ private[ml] trait HasProbabilityCol extends Params {
}

/**
* Trait for shared param threshold.
* Trait for shared param threshold (default: 0.5).
*/
private[ml] trait HasThreshold extends Params {

Expand All @@ -149,6 +149,8 @@ private[ml] trait HasThreshold extends Params {
*/
final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))

setDefault(threshold, 0.5)

/** @group getParam */
def getThreshold: Double = $(threshold)
}
Expand All @@ -165,7 +167,7 @@ private[ml] trait HasThresholds extends Params {
final val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", (t: Array[Double]) => t.forall(_ >= 0))

/** @group getParam */
final def getThresholds: Array[Double] = $(thresholds)
def getThresholds: Array[Double] = $(thresholds)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ public void logisticRegressionWithSetters() {
assert(r.getDouble(0) == 0.0);
}
// Call transform with params, and check that the params worked.
double[] thresholds = {1.0, 0.0};
model.transform(
dataset, model.thresholds().w(thresholds), model.probabilityCol().w("myProb"))
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
Expand All @@ -112,9 +110,8 @@ public void logisticRegressionWithSetters() {
assert(foundNonZero);

// Call fit() with new params, and check as many params as we can.
double[] thresholds2 = {0.6, 0.4};
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.thresholds().w(thresholds2), lr.probabilityCol().w("theProb"));
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,40 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("setThreshold, getThreshold") {
val lr = new LogisticRegression
// default
withClue("LogisticRegression should not have thresholds set by default") {
intercept[java.util.NoSuchElementException] {
assert(lr.getThreshold === 0.5, "LogisticRegression.threshold should default to 0.5")
withClue("LogisticRegression should not have thresholds set by default.") {
intercept[java.util.NoSuchElementException] { // Note: The exception type may change in future
lr.getThresholds
}
}
// Set via thresholds.
// Set via threshold.
// Intuition: Large threshold or large thresholds(1) makes class 0 more likely.
lr.setThreshold(1.0)
assert(lr.getThresholds === Array(0.0, 1.0))
lr.setThreshold(0.0)
assert(lr.getThresholds === Array(1.0, 0.0))
lr.setThreshold(0.5)
assert(lr.getThresholds === Array(0.5, 0.5))
// Test getThreshold
lr.setThresholds(Array(0.3, 0.7))
// Set via thresholds
val lr2 = new LogisticRegression
lr2.setThresholds(Array(0.3, 0.7))
val expectedThreshold = 1.0 / (1.0 + 0.3 / 0.7)
assert(lr.getThreshold ~== expectedThreshold relTol 1E-7)
assert(lr2.getThreshold ~== expectedThreshold relTol 1E-7)
// thresholds and threshold must be consistent
lr2.setThresholds(Array(0.1, 0.2, 0.3))
withClue("getThreshold should throw error if thresholds has length != 2.") {
intercept[IllegalArgumentException] {
lr2.getThreshold
}
}
// thresholds and threshold must be consistent: values
withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
intercept[IllegalArgumentException] {
val lr2model = lr2.fit(dataset,
lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
lr2model.getThreshold
}
}
}

test("logistic regression doesn't fit intercept when fitIntercept is off") {
Expand Down Expand Up @@ -145,16 +162,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
// Call transform with params, and check that the params worked.
val predNotAllZero =
model.transform(dataset, model.thresholds -> Array(1.0, 0.0),
model.transform(dataset, model.threshold -> 0.0,
model.probabilityCol -> "myProb")
.select("prediction", "myProb")
.collect()
.map { case Row(pred: Double, prob: Vector) => pred }
assert(predNotAllZero.exists(_ !== 0.0))

// Call fit() with new params, and check as many params as we can.
lr.setThresholds(Array(0.6, 0.4))
val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1,
lr.thresholds -> Array(0.6, 0.4),
lr.probabilityCol -> "theProb")
val parent2 = model2.parent.asInstanceOf[LogisticRegression]
assert(parent2.getMaxIter === 5)
Expand Down
Loading