diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 75d9a0e8cac1..8dbabc6d502a 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -17,7 +17,7 @@
 
 from pyspark import since, keyword_only
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
 from pyspark.ml.param.shared import *
 from pyspark.ml.common import inherit_doc
 
@@ -56,8 +56,83 @@ def gaussiansDF(self):
         """
         return self._call_java("gaussiansDF")
 
+    @property
+    @since("2.0.0")
+    def summary(self):
+        """
+        Gets summary of model on training set. An exception is thrown if
+        `trainingSummary is None`.
+        """
+        java_gmt_summary = self._call_java("summary")
+        return GaussianMixtureSummary(java_gmt_summary)
+
+    @property
+    @since("2.0.0")
+    def hasSummary(self):
+        """
+        Indicates whether a training summary exists for this model
+        instance.
+        """
+        return self._call_java("hasSummary")
+
+
+class GaussianMixtureSummary(JavaWrapper):
+    """
+    Abstraction for Gaussian Mixture Results for a given model.
+
+    .. versionadded:: 2.0.0
+    """
+
+    @property
+    @since("2.0.0")
+    def predictions(self):
+        """
+        Dataframe outputted by the model's `transform` method.
+        """
+        return self._call_java("predictions")
+
+    @property
+    @since("2.0.0")
+    def probabilityCol(self):
+        """
+        Field in "predictions" which gives the probability
+        of each cluster.
+        """
+        return self._call_java("probabilityCol")
+
+    @property
+    @since("2.0.0")
+    def featuresCol(self):
+        """
+        Field in "predictions" which gives the features of each instance.
+        """
+        return self._call_java("featuresCol")
+
+    @property
+    @since("2.0.0")
+    def cluster(self):
+        """
+        Cluster centers of the transformed data.
+        """
+        return self._call_java("cluster")
+
+    @property
+    @since("2.0.0")
+    def probability(self):
+        """
+        Probability of each cluster.
+        """
+        return self._call_java("probability")
+
+    @property
+    @since("2.0.0")
+    def clusterSizes(self):
+        """
+        Size of (number of data points in) each cluster.
+        """
+        return self._call_java("clusterSizes")
+
+
-@inherit_doc
 class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
                       HasProbabilityCol, JavaMLWritable, JavaMLReadable):
     """
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 981ed9dda042..ebfc13ac740e 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1070,6 +1070,21 @@ def test_logistic_regression_summary(self):
         sameSummary = model.evaluate(df)
         self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
 
+    def test_gaussian_mixture_summary(self):
+        from pyspark.mllib.linalg import Vectors
+        df = self.spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
+        gm = GaussianMixture(k=3, tol=0.0001, maxIter=10, seed=10)
+        model = gm.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        # test that api is callable and returns expected types
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.featuresCol, "features")
+        cluster_sizes = s.clusterSizes
+        self.assertTrue(isinstance(s.cluster, DataFrame))
+        self.assertTrue(isinstance(s.probability, DataFrame))
+        self.assertTrue(isinstance(cluster_sizes[0], int))
+
 
 class OneVsRestTests(SparkSessionTestCase):
 
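
Usage sketch (not part of the patch): assuming a local SparkSession and the same dataset path and parameters used in the test above, the summary API added in this diff could be exercised roughly as follows.

    from pyspark.sql import SparkSession
    from pyspark.ml.clustering import GaussianMixture

    # Assumes the data/mllib/sample_kmeans_data.txt file shipped with Spark is available
    # relative to the working directory, as in the test above.
    spark = SparkSession.builder.getOrCreate()
    df = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

    gm = GaussianMixture(k=3, tol=0.0001, maxIter=10, seed=10)
    model = gm.fit(df)

    if model.hasSummary:
        summary = model.summary
        print(summary.featuresCol)      # "features"
        print(summary.probabilityCol)   # name of the probability column in `predictions`
        print(summary.clusterSizes)     # list of ints, one entry per cluster
        summary.cluster.show()          # DataFrame exposed by the cluster property above
        summary.probability.show()      # DataFrame exposed by the probability property above
        summary.predictions.show()      # Dataframe outputted by the model's transform method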