Skip to content

Commit fe2c424

Browse files
committed
add llk in summary
1 parent 8d8bb24 commit fe2c424

File tree

3 files changed

+14
-22
lines changed

3 files changed

+14
-22
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -129,25 +129,6 @@ class GaussianMixtureModel private[ml] (
129129
Vectors.dense(probs)
130130
}
131131

132-
/**
133-
* Return the total log-likelihood for this model on the given data.
134-
*/
135-
@Since("2.2.0")
136-
def computeLogLikelihood(dataset: Dataset[_]): Double = {
137-
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
138-
val spark = dataset.sparkSession
139-
import spark.implicits._
140-
141-
val bcWeightAndDists = spark.sparkContext.broadcast(weights.zip(gaussians))
142-
dataset.select(col($(featuresCol))).map {
143-
case Row(feature: Vector) =>
144-
val likelihood = bcWeightAndDists.value.map {
145-
case (weight, dist) => EPSILON + weight * dist.pdf(feature)
146-
}.sum
147-
math.log(likelihood)
148-
}.reduce(_ + _)
149-
}
150-
151132
/**
152133
* Retrieve Gaussian distributions as a DataFrame.
153134
* Each row represents a Gaussian Distribution.
@@ -435,7 +416,7 @@ class GaussianMixture @Since("2.0.0") (
435416

436417
val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this)
437418
val summary = new GaussianMixtureSummary(model.transform(dataset),
438-
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
419+
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood)
439420
model.setSummary(Some(summary))
440421
instr.logSuccess(model)
441422
model
@@ -693,6 +674,7 @@ private class ExpectationAggregator(
693674
* in `predictions`.
694675
* @param featuresCol Name for column of features in `predictions`.
695676
* @param k Number of clusters.
677+
* @param logLikelihood Total log-likelihood for this model on the given data.
696678
*/
697679
@Since("2.0.0")
698680
@Experimental
@@ -701,7 +683,9 @@ class GaussianMixtureSummary private[clustering] (
701683
predictionCol: String,
702684
@Since("2.0.0") val probabilityCol: String,
703685
featuresCol: String,
704-
k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) {
686+
k: Int,
687+
@Since("2.2.0") val logLikelihood: Double)
688+
extends ClusteringSummary(predictions, predictionCol, featuresCol, k) {
705689

706690
/**
707691
* Probability of each cluster.

mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
224224
val actual = new GaussianMixture().setK(2).setSeed(seed).fit(rDataset)
225225
modelEquals(expected, actual)
226226

227-
val llk = expected.computeLogLikelihood(rDataset)
227+
val llk = actual.summary.logLikelihood
228228
assert(llk ~== -46.89499 absTol 1E-5)
229229
}
230230

python/pyspark/ml/clustering.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,14 @@ def probability(self):
281281
"""
282282
return self._call_java("probability")
283283

284+
@property
285+
@since("2.2.0")
286+
def logLikelihood(self):
287+
"""
288+
Total log-likelihood for this model on the given data.
289+
"""
290+
return self._call_java("logLikelihood")
291+
284292

285293
class KMeansSummary(ClusteringSummary):
286294
"""

0 commit comments

Comments
 (0)