Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,13 @@ class LogisticRegressionModel private[spark] (
*/
@Since("1.6.0")
override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)

@Since("2.0.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will need to update this since :)

override def toString: String = {
val td = getDefault(threshold)
s"${super.toString}, numClasses = $numClasses, " +
s"numFeatures = $numFeatures threshold = ${td.getOrElse("None")}"
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,13 @@ class LogisticRegressionSuite
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
}

test("toString") {
val lrModel = new LogisticRegressionModel(uid = "lrModeltest",
coefficients = Vectors.dense(0.1, 0.2, 0.3), intercept = 0.5)
val expected: String = "lrModeltest, numClasses = 2, numFeatures = 3 threshold = 0.5"
assert(lrModel.toString === expected)
}
}

object LogisticRegressionSuite {
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def evaluate(self, dataset):
java_blr_summary = self._call_java("evaluate", dataset)
return BinaryLogisticRegressionSummary(java_blr_summary)

@since("2.0.0")
def __repr__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might make sense to include this in the docstring since it might be useful for people to see how to use it :)

return self._call_java("toString")


class LogisticRegressionSummary(JavaWrapper):
"""
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/mllib/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ def load(cls, sc, path):
model.setThreshold(threshold)
return model

@since("2.0.0")
def __repr__(self):
return self._call_java("toString")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think Python style asks for two new lines here (try running ./dev/lint-python locally :))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok :)


class LogisticRegressionWithSGD(object):
"""
Expand Down