Skip to content

Commit f062750

Browse files
committed
Add linkPredictionCol to GeneralizedLinearRegression and fix the PyDoc to generate the bullet list
1 parent 3ded5bc commit f062750

File tree

1 file changed

+37
-11
lines changed

1 file changed

+37
-11
lines changed

python/pyspark/ml/regression.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,10 +1245,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12451245
predictor (link function) and a description of the error distribution (family). It supports
12461246
"gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
12471247
are listed below. The first link function of each family is the default one.
1248-
- "gaussian" -> "identity", "log", "inverse"
1249-
- "binomial" -> "logit", "probit", "cloglog"
1250-
- "poisson" -> "log", "identity", "sqrt"
1251-
- "gamma" -> "inverse", "identity", "log"
1248+
1249+
* "gaussian" -> "identity", "log", "inverse"
1250+
1251+
* "binomial" -> "logit", "probit", "cloglog"
1252+
1253+
* "poisson" -> "log", "identity", "sqrt"
1254+
1255+
* "gamma" -> "inverse", "identity", "log"
12521256
12531257
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
12541258
@@ -1258,9 +1262,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12581262
... (1.0, Vectors.dense(1.0, 2.0)),
12591263
... (2.0, Vectors.dense(0.0, 0.0)),
12601264
... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
1261-
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
1265+
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
1266+
>>> print(glr.getLinkPredictionCol())
1267+
p
12621268
>>> model = glr.fit(df)
1263-
>>> abs(model.transform(df).head().prediction - 1.5) < 0.001
1269+
>>> transformed = model.transform(df)
1270+
>>> abs(transformed.head().prediction - 1.5) < 0.001
1271+
True
1272+
>>> abs(transformed.head().p - 1.5) < 0.001
12641273
True
12651274
>>> model.coefficients
12661275
DenseVector([1.5..., -1.0...])
@@ -1290,32 +1299,35 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12901299
"relationship between the linear predictor and the mean of the distribution " +
12911300
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
12921301
"and sqrt.", typeConverter=TypeConverters.toString)
1302+
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
1303+
"predictor) column name", typeConverter=TypeConverters.toString)
12931304

12941305
@keyword_only
12951306
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
12961307
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1297-
regParam=0.0, weightCol=None, solver="irls"):
1308+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
12981309
"""
12991310
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
13001311
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1301-
regParam=0.0, weightCol=None, solver="irls")
1312+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
13021313
"""
13031314
super(GeneralizedLinearRegression, self).__init__()
13041315
self._java_obj = self._new_java_obj(
13051316
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
1306-
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
1317+
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
1318+
linkPredictionCol="")
13071319
kwargs = self.__init__._input_kwargs
13081320
self.setParams(**kwargs)
13091321

13101322
@keyword_only
13111323
@since("2.0.0")
13121324
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
13131325
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
1314-
regParam=0.0, weightCol=None, solver="irls"):
1326+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
13151327
"""
13161328
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
13171329
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1318-
regParam=0.0, weightCol=None, solver="irls")
1330+
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
13191331
Sets params for generalized linear regression.
13201332
"""
13211333
kwargs = self.setParams._input_kwargs
@@ -1338,6 +1350,20 @@ def getFamily(self):
13381350
"""
13391351
return self.getOrDefault(self.family)
13401352

1353+
@since("2.0.0")
1354+
def setLinkPredictionCol(self, value):
1355+
"""
1356+
Sets the value of :py:attr:`linkPredictionCol`.
1357+
"""
1358+
return self._set(linkPredictionCol=value)
1359+
1360+
@since("2.0.0")
1361+
def getLinkPredictionCol(self):
1362+
"""
1363+
Gets the value of linkPredictionCol or its default value.
1364+
"""
1365+
return self.getOrDefault(self.linkPredictionCol)
1366+
13411367
@since("2.0.0")
13421368
def setLink(self, value):
13431369
"""

0 commit comments

Comments
 (0)