Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame

/**
* ::Experimental::
Expand All @@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
@Experimental
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {

/**
* An auxiliary constructor taking a DataFrame.
* @param predictionAndLabels a DataFrame with two double columns: prediction and label
*/
private[mllib] def this(predictionAndLabels: DataFrame) =
this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))

private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
private lazy val labelCount: Long = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels
Expand Down
129 changes: 129 additions & 0 deletions python/pyspark/mllib/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,135 @@ def r2(self):
return self.call("r2")


class MulticlassMetrics(JavaModelWrapper):
"""
Evaluator for multiclass classification.

>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
>>> metrics = MulticlassMetrics(predictionAndLabels)
>>> metrics.falsePositiveRate(0.0)
0.2...
>>> metrics.precision(1.0)
0.75...
>>> metrics.recall(2.0)
1.0...
>>> metrics.fMeasure(0.0, 2.0)
0.52...
>>> metrics.precision()
0.66...
>>> metrics.recall()
0.66...
>>> metrics.weightedFalsePositiveRate
0.19...
>>> metrics.weightedPrecision
0.68...
>>> metrics.weightedRecall
0.66...
>>> metrics.weightedFMeasure()
0.66...
>>> metrics.weightedFMeasure(2.0)
0.65...
"""

def __init__(self, predictionAndLabels):
"""
:param predictionAndLabels an RDD of (prediction, label) pairs.
"""
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
StructField("prediction", DoubleType(), nullable=False),
StructField("label", DoubleType(), nullable=False)]))
java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
java_model = java_class(df._jdf)
super(MulticlassMetrics, self).__init__(java_model)

def truePositiveRate(self, label):
"""
Returns true positive rate for a given label (category).
"""
return self.call("truePositiveRate", label)

def falsePositiveRate(self, label):
"""
Returns false positive rate for a given label (category).
"""
return self.call("falsePositiveRate", label)

def precision(self, label=None):
"""
Returns precision or precision for a given label (category) if specified.
"""
if label is None:
return self.call("precision")
else:
return self.call("precision", float(label))

def recall(self, label=None):
"""
Returns recall or recall for a given label (category) if specified.
"""
if label is None:
return self.call("recall")
else:
return self.call("recall", float(label))

def fMeasure(self, label=None, beta=None):
"""
Returns f-measure or f-measure for a given label (category) if specified.
"""
if beta is None:
if label is None:
return self.call("fMeasure")
else:
return self.call("fMeasure", label)
else:
if label is None:
raise Exception("If the beta parameter is specified, label can not be none")
else:
return self.call("fMeasure", label, beta)

@property
def weightedTruePositiveRate(self):
"""
Returns weighted true positive rate.
(equals to precision, recall and f-measure)
"""
return self.call("weightedTruePositiveRate")

@property
def weightedFalsePositiveRate(self):
"""
Returns weighted false positive rate.
"""
return self.call("weightedFalsePositiveRate")

@property
def weightedRecall(self):
"""
Returns weighted averaged recall.
(equals to precision, recall and f-measure)
"""
return self.call("weightedRecall")

@property
def weightedPrecision(self):
"""
Returns weighted averaged precision.
"""
return self.call("weightedPrecision")

def weightedFMeasure(self, beta=None):
"""
Returns weighted averaged f-measure.
"""
if beta is None:
return self.call("weightedFMeasure")
else:
return self.call("weightedFMeasure", beta)


def _test():
import doctest
from pyspark import SparkContext
Expand Down