diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index b29b5ac70e6f..d05d737a4e9c 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -296,6 +296,17 @@ def probability(self):
         return self._call_java("probability")
 
 
+class KMeansSummary(ClusteringSummary):
+    """
+    .. note:: Experimental
+
+    Summary of KMeans.
+
+    .. versionadded:: 2.1.0
+    """
+    pass
+
+
 class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by KMeans.
@@ -316,6 +327,27 @@ def computeCost(self, dataset):
         """
         return self._call_java("computeCost", dataset)
 
+    @property
+    @since("2.1.0")
+    def hasSummary(self):
+        """
+        Indicates whether a training summary exists for this model instance.
+        """
+        return self._call_java("hasSummary")
+
+    @property
+    @since("2.1.0")
+    def summary(self):
+        """
+        Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
+        training set. An exception is thrown if no summary exists.
+        """
+        if self.hasSummary:
+            return KMeansSummary(self._call_java("summary"))
+        else:
+            raise RuntimeError("No training summary available for this %s" %
+                               self.__class__.__name__)
+
 
 @inherit_doc
 class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
@@ -341,6 +373,13 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     True
     >>> rows[2].prediction == rows[3].prediction
     True
+    >>> model.hasSummary
+    True
+    >>> summary = model.summary
+    >>> summary.k
+    2
+    >>> summary.clusterSizes
+    [2, 2]
     >>> kmeans_path = temp_path + "/kmeans"
     >>> kmeans.save(kmeans_path)
     >>> kmeans2 = KMeans.load(kmeans_path)
@@ -349,6 +388,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     >>> model_path = temp_path + "/kmeans_model"
     >>> model.save(model_path)
     >>> model2 = KMeansModel.load(model_path)
+    >>> model2.hasSummary
+    False
     >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
     array([ True,  True], dtype=bool)
     >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c0f0d4073564..a0c288a0b71a 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1129,6 +1129,21 @@ def test_bisecting_kmeans_summary(self):
         self.assertEqual(len(s.clusterSizes), 2)
         self.assertEqual(s.k, 2)
 
+    def test_kmeans_summary(self):
+        data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+                (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+        df = self.spark.createDataFrame(data, ["features"])
+        kmeans = KMeans(k=2, seed=1)
+        model = kmeans.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        self.assertTrue(isinstance(s.cluster, DataFrame))
+        self.assertEqual(len(s.clusterSizes), 2)
+        self.assertEqual(s.k, 2)
+
 
 class OneVsRestTests(SparkSessionTestCase):
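
For reference, a minimal usage sketch of the summary API this patch adds, mirroring the doctest and the new test case above; it assumes a local SparkSession and the same toy data used in the test:

    from pyspark.sql import SparkSession
    from pyspark.ml.clustering import KMeans
    from pyspark.ml.linalg import Vectors

    spark = SparkSession.builder.getOrCreate()

    # Same toy data as test_kmeans_summary above.
    data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
            (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
    df = spark.createDataFrame(data, ["features"])

    model = KMeans(k=2, seed=1).fit(df)

    # The summary is only attached to the model returned by fit();
    # as the doctest shows, a model reloaded from disk has hasSummary == False.
    if model.hasSummary:
        summary = model.summary
        print(summary.k)             # 2
        print(summary.clusterSizes)  # [2, 2]
        summary.predictions.show()   # training rows with the "prediction" column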