Skip to content

Commit 8ccca91

Browse files
zhengruifeng authored and yanboliang committed
[SPARK-14272][ML] Add Loglikelihood in GaussianMixtureSummary
## What changes were proposed in this pull request?
Add the log-likelihood to GMM.summary.
## How was this patch tested?
Added tests.
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Author: Ruifeng Zheng <ruifengz@foxmail.com>
Closes #12064 from zhengruifeng/gmm_metric.
1 parent 2e62560 commit 8ccca91

File tree

5 files changed

+27
-4
lines changed

5 files changed

+27
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -416,7 +416,7 @@ class GaussianMixture @Since("2.0.0") (
416416

417417
val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)).setParent(this)
418418
val summary = new GaussianMixtureSummary(model.transform(dataset),
419-
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
419+
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood)
420420
model.setSummary(Some(summary))
421421
instr.logSuccess(model)
422422
model
@@ -674,6 +674,7 @@ private class ExpectationAggregator(
674674
* in `predictions`.
675675
* @param featuresCol Name for column of features in `predictions`.
676676
* @param k Number of clusters.
677+
* @param logLikelihood Total log-likelihood for this model on the given data.
677678
*/
678679
@Since("2.0.0")
679680
@Experimental
@@ -682,7 +683,9 @@ class GaussianMixtureSummary private[clustering] (
682683
predictionCol: String,
683684
@Since("2.0.0") val probabilityCol: String,
684685
featuresCol: String,
685-
k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k) {
686+
k: Int,
687+
@Since("2.2.0") val logLikelihood: Double)
688+
extends ClusteringSummary(predictions, predictionCol, featuresCol, k) {
686689

687690
/**
688691
* Probability of each cluster.

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel
3434
import org.apache.spark.util.random.BernoulliCellSampler
3535

3636
/**
37-
* Helper methods to load, save and pre-process data used in ML Lib.
37+
* Helper methods to load, save and pre-process data used in MLLib.
3838
*/
3939
@Since("0.8.0")
4040
object MLUtils extends Logging {

mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -207,6 +207,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
207207
[,1] [,2]
208208
[1,] 0.2961543 0.160783
209209
[2,] 0.1607830 1.008878
210+
211+
model$loglik
212+
213+
[1] -46.89499
210214
*/
211215
val weights = Array(0.5333333, 0.4666667)
212216
val means = Array(Vectors.dense(10.363673, 9.897081), Vectors.dense(0.11731091, -0.06192351))
@@ -219,6 +223,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
219223
val expected = new GaussianMixtureModel("dummy", weights, gaussians)
220224
val actual = new GaussianMixture().setK(2).setSeed(seed).fit(rDataset)
221225
modelEquals(expected, actual)
226+
227+
val llk = actual.summary.logLikelihood
228+
assert(llk ~== -46.89499 absTol 1E-5)
222229
}
223230

224231
test("upper triangular matrix unpacking") {

project/MimaExcludes.scala

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,10 @@ object MimaExcludes {
4646
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted"),
4747

4848
// [SPARK-19148][SQL] do not expose the external table concept in Catalog
49-
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable")
49+
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable"),
50+
51+
// [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary
52+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this")
5053
)
5154

5255
// Exclude rules for 2.1.x

python/pyspark/ml/clustering.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -175,6 +175,8 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
175175
3
176176
>>> summary.clusterSizes
177177
[2, 2, 2]
178+
>>> summary.logLikelihood
179+
8.14636...
178180
>>> weights = model.weights
179181
>>> len(weights)
180182
3
@@ -281,6 +283,14 @@ def probability(self):
281283
"""
282284
return self._call_java("probability")
283285

286+
@property
287+
@since("2.2.0")
288+
def logLikelihood(self):
289+
"""
290+
Total log-likelihood for this model on the given data.
291+
"""
292+
return self._call_java("logLikelihood")
293+
284294

285295
class KMeansSummary(ClusteringSummary):
286296
"""

0 commit comments

Comments
 (0)