diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index c2b440059b1f..4466d68210e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -625,6 +625,13 @@ class LogisticRegressionModel private[spark] ( */ @Since("1.6.0") override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this) + + @Since("2.0.0") + override def toString: String = { + val td = getDefault(threshold) + s"${super.toString}, numClasses = $numClasses, " + + s"numFeatures = $numFeatures threshold = ${td.getOrElse("None")}" + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 48db4281309b..d8eb16bec624 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -943,6 +943,13 @@ class LogisticRegressionSuite assert(expected.coefficients.toArray === actual.coefficients.toArray) } } + + test("toString") { + val lrModel = new LogisticRegressionModel(uid = "lrModeltest", + coefficients = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5) + val expected: String = "lrModeltest, numClasses = 2, numFeatures = 3 threshold = 0.5" + assert(lrModel.toString === expected) + } } object LogisticRegressionSuite { diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index de1321b13975..c21a6c9000f4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -273,6 +273,10 @@ def evaluate(self, dataset): java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) + @since("2.0.0") + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionSummary(JavaWrapper): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 57106f8690a7..c0cda2e5c4e4 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -262,6 +262,10 @@ def load(cls, sc, path): model.setThreshold(threshold) return model + @since("2.0.0") + def __repr__(self): + return self._call_java("toString") + class LogisticRegressionWithSGD(object): """