Skip to content

Commit 31002e4

Browse files
sethah authored and yanboliang committed
[SPARK-18282][ML][PYSPARK] Add python clustering summaries for GMM and BKM
## What changes were proposed in this pull request?

Add model summary APIs for `GaussianMixtureModel` and `BisectingKMeansModel` in pyspark.

## How was this patch tested?

Unit tests.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #15777 from sethah/pyspark_cluster_summaries.

(cherry picked from commit e811fbf)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
1 parent fb4e635 commit 31002e4

File tree

16 files changed: +256 additions, -47 deletions

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -648,7 +648,7 @@ class LogisticRegression @Since("1.2.0") (
648648
$(labelCol),
649649
$(featuresCol),
650650
objectiveHistory)
651-
model.setSummary(logRegSummary)
651+
model.setSummary(Some(logRegSummary))
652652
} else {
653653
model
654654
}
@@ -790,9 +790,9 @@ class LogisticRegressionModel private[spark] (
790790
}
791791
}
792792

793-
private[classification] def setSummary(
794-
summary: LogisticRegressionTrainingSummary): this.type = {
795-
this.trainingSummary = Some(summary)
793+
private[classification]
794+
def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
795+
this.trainingSummary = summary
796796
this
797797
}
798798

@@ -887,8 +887,7 @@ class LogisticRegressionModel private[spark] (
887887
override def copy(extra: ParamMap): LogisticRegressionModel = {
888888
val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
889889
numClasses, isMultinomial), extra)
890-
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
891-
newModel.setParent(parent)
890+
newModel.setSummary(trainingSummary).setParent(parent)
892891
}
893892

894893
override protected def raw2prediction(rawPrediction: Vector): Double = {

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -95,8 +95,7 @@ class BisectingKMeansModel private[ml] (
9595
@Since("2.0.0")
9696
override def copy(extra: ParamMap): BisectingKMeansModel = {
9797
val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
98-
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
99-
copied.setParent(this.parent)
98+
copied.setSummary(trainingSummary).setParent(this.parent)
10099
}
101100

102101
@Since("2.0.0")
@@ -132,8 +131,8 @@ class BisectingKMeansModel private[ml] (
132131

133132
private var trainingSummary: Option[BisectingKMeansSummary] = None
134133

135-
private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = {
136-
this.trainingSummary = Some(summary)
134+
private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = {
135+
this.trainingSummary = summary
137136
this
138137
}
139138

@@ -265,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") (
265264
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
266265
val summary = new BisectingKMeansSummary(
267266
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
268-
model.setSummary(summary)
267+
model.setSummary(Some(summary))
269268
instr.logSuccess(model)
270269
model
271270
}

mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,7 @@ class GaussianMixtureModel private[ml] (
9090
@Since("2.0.0")
9191
override def copy(extra: ParamMap): GaussianMixtureModel = {
9292
val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
93-
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
94-
copied.setParent(this.parent)
93+
copied.setSummary(trainingSummary).setParent(this.parent)
9594
}
9695

9796
@Since("2.0.0")
@@ -150,8 +149,8 @@ class GaussianMixtureModel private[ml] (
150149

151150
private var trainingSummary: Option[GaussianMixtureSummary] = None
152151

153-
private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = {
154-
this.trainingSummary = Some(summary)
152+
private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = {
153+
this.trainingSummary = summary
155154
this
156155
}
157156

@@ -340,7 +339,7 @@ class GaussianMixture @Since("2.0.0") (
340339
.setParent(this)
341340
val summary = new GaussianMixtureSummary(model.transform(dataset),
342341
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
343-
model.setSummary(summary)
342+
model.setSummary(Some(summary))
344343
instr.logNumFeatures(model.gaussians.head.mean.size)
345344
instr.logSuccess(model)
346345
model

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -110,8 +110,7 @@ class KMeansModel private[ml] (
110110
@Since("1.5.0")
111111
override def copy(extra: ParamMap): KMeansModel = {
112112
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
113-
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
114-
copied.setParent(this.parent)
113+
copied.setSummary(trainingSummary).setParent(this.parent)
115114
}
116115

117116
/** @group setParam */
@@ -165,8 +164,8 @@ class KMeansModel private[ml] (
165164

166165
private var trainingSummary: Option[KMeansSummary] = None
167166

168-
private[clustering] def setSummary(summary: KMeansSummary): this.type = {
169-
this.trainingSummary = Some(summary)
167+
private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = {
168+
this.trainingSummary = summary
170169
this
171170
}
172171

@@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") (
325324
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
326325
val summary = new KMeansSummary(
327326
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
328-
model.setSummary(summary)
327+
model.setSummary(Some(summary))
329328
instr.logSuccess(model)
330329
model
331330
}

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
270270
.setParent(this))
271271
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
272272
wlsModel.diagInvAtWA.toArray, 1, getSolver)
273-
return model.setSummary(trainingSummary)
273+
return model.setSummary(Some(trainingSummary))
274274
}
275275

276276
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
@@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
284284
.setParent(this))
285285
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
286286
irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
287-
model.setSummary(trainingSummary)
287+
model.setSummary(Some(trainingSummary))
288288
}
289289

290290
@Since("2.0.0")
@@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] (
761761
def hasSummary: Boolean = trainingSummary.nonEmpty
762762

763763
private[regression]
764-
def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = {
765-
this.trainingSummary = Some(summary)
764+
def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = {
765+
this.trainingSummary = summary
766766
this
767767
}
768768

@@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] (
778778
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
779779
val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
780780
extra)
781-
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
782-
copied.setParent(parent)
781+
copied.setSummary(trainingSummary).setParent(parent)
783782
}
784783

785784
/**

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -225,7 +225,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
225225
model.diagInvAtWA.toArray,
226226
model.objectiveHistory)
227227

228-
return lrModel.setSummary(trainingSummary)
228+
return lrModel.setSummary(Some(trainingSummary))
229229
}
230230

231231
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -278,7 +278,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
278278
model,
279279
Array(0D),
280280
Array(0D))
281-
return model.setSummary(trainingSummary)
281+
return model.setSummary(Some(trainingSummary))
282282
} else {
283283
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
284284
"Model cannot be regularized.")
@@ -400,7 +400,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
400400
model,
401401
Array(0D),
402402
objectiveHistory)
403-
model.setSummary(trainingSummary)
403+
model.setSummary(Some(trainingSummary))
404404
}
405405

406406
@Since("1.4.0")
@@ -446,8 +446,9 @@ class LinearRegressionModel private[ml] (
446446
throw new SparkException("No training summary available for this LinearRegressionModel")
447447
}
448448

449-
private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
450-
this.trainingSummary = Some(summary)
449+
private[regression]
450+
def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = {
451+
this.trainingSummary = summary
451452
this
452453
}
453454

@@ -490,8 +491,7 @@ class LinearRegressionModel private[ml] (
490491
@Since("1.4.0")
491492
override def copy(extra: ParamMap): LinearRegressionModel = {
492493
val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra)
493-
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
494-
newModel.setParent(parent)
494+
newModel.setSummary(trainingSummary).setParent(parent)
495495
}
496496

497497
/**

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,8 @@ class LogisticRegressionSuite
147147
assert(model.hasSummary)
148148
val copiedModel = model.copy(ParamMap.empty)
149149
assert(copiedModel.hasSummary)
150+
model.setSummary(None)
151+
assert(!model.hasSummary)
150152
}
151153

152154
test("empty probabilityCol") {

mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -109,6 +109,9 @@ class BisectingKMeansSuite
109109
assert(clusterSizes.length === k)
110110
assert(clusterSizes.sum === numRows)
111111
assert(clusterSizes.forall(_ >= 0))
112+
113+
model.setSummary(None)
114+
assert(!model.hasSummary)
112115
}
113116

114117
test("read/write") {

mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
111111
assert(clusterSizes.length === k)
112112
assert(clusterSizes.sum === numRows)
113113
assert(clusterSizes.forall(_ >= 0))
114+
115+
model.setSummary(None)
116+
assert(!model.hasSummary)
114117
}
115118

116119
test("read/write") {

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
123123
assert(clusterSizes.length === k)
124124
assert(clusterSizes.sum === numRows)
125125
assert(clusterSizes.forall(_ >= 0))
126+
127+
model.setSummary(None)
128+
assert(!model.hasSummary)
126129
}
127130

128131
test("KMeansModel transform with non-default feature and prediction cols") {

0 commit comments

Comments
 (0)