Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid

/**
* param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
* `"weightedPrecision"`, `"weightedRecall"`)
* `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`)
* @group param
*/
@Since("1.5.0")
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("f1", "precision",
"recall", "weightedPrecision", "weightedRecall"))
"recall", "weightedPrecision", "weightedRecall", "accuracy"))
new Param(this, "metricName", "metric name in evaluation " +
"(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams)
"(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", allowedParams)
}

/** @group getParam */
Expand Down Expand Up @@ -86,18 +86,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
case "recall" => metrics.recall
case "weightedPrecision" => metrics.weightedPrecision
case "weightedRecall" => metrics.weightedRecall
case "accuracy" => metrics.accuracy
}
metric
}

@Since("1.5.0")
override def isLargerBetter: Boolean = $(metricName) match {
case "f1" => true
case "precision" => true
case "recall" => true
case "weightedPrecision" => true
case "weightedRecall" => true
}
override def isLargerBetter: Boolean = true

@Since("1.5.0")
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* Returns precision
*/
@Since("1.1.0")
lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount
@deprecated("Use accuracy.", "2.0.0")
lazy val precision: Double = accuracy

/**
* Returns recall
Expand All @@ -148,14 +149,24 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* of all false negatives)
*/
@Since("1.1.0")
lazy val recall: Double = precision
@deprecated("Use accuracy.", "2.0.0")
lazy val recall: Double = accuracy

/**
* Returns f-measure
* (equals to precision and recall because precision equals recall)
*/
@Since("1.1.0")
lazy val fMeasure: Double = precision
@deprecated("Use accuracy.", "2.0.0")
lazy val fMeasure: Double = accuracy

/**
* Returns accuracy
* (equals to the total number of correctly classified instances
* out of the total number of instances.)
*/
@Since("2.0.0")
lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount

/**
* Returns weighted true positive rate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)

assert(math.abs(metrics.recall -
assert(math.abs(metrics.accuracy -
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
assert(math.abs(metrics.recall - metrics.precision) < delta)
assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
assert(math.abs(metrics.accuracy - metrics.precision) < delta)
assert(math.abs(metrics.accuracy - metrics.recall) < delta)
assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta)
assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta)
assert(math.abs(metrics.weightedFalsePositiveRate -
((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta)
assert(math.abs(metrics.weightedPrecision -
Expand Down