@@ -1245,10 +1245,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12451245 predictor (link function) and a description of the error distribution (family). It supports
12461246 "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
12471247 is listed below. The first link function of each family is the default one.
1248- - "gaussian" -> "identity", "log", "inverse"
1249- - "binomial" -> "logit", "probit", "cloglog"
1250- - "poisson" -> "log", "identity", "sqrt"
1251- - "gamma" -> "inverse", "identity", "log"
1248+
1249+ * "gaussian" -> "identity", "log", "inverse"
1250+
1251+ * "binomial" -> "logit", "probit", "cloglog"
1252+
1253+ * "poisson" -> "log", "identity", "sqrt"
1254+
1255+ * "gamma" -> "inverse", "identity", "log"
12521256
12531257 .. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
12541258
@@ -1258,9 +1262,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12581262 ... (1.0, Vectors.dense(1.0, 2.0)),
12591263 ... (2.0, Vectors.dense(0.0, 0.0)),
12601264 ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
1261- >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
1265+ >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
1266+ >>> print(glr.getLinkPredictionCol())
1267+ p
12621268 >>> model = glr.fit(df)
1263- >>> abs(model.transform(df).head().prediction - 1.5) < 0.001
1269+ >>> transformed = model.transform(df)
1270+ >>> abs(transformed.head().prediction - 1.5) < 0.001
1271+ True
1272+ >>> abs(transformed.head().p - 1.5) < 0.001
12641273 True
12651274 >>> model.coefficients
12661275 DenseVector([1.5..., -1.0...])
@@ -1290,32 +1299,35 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
12901299 "relationship between the linear predictor and the mean of the distribution " +
12911300 "function. Supported options: identity, log, inverse, logit, probit, cloglog " +
12921301 "and sqrt." , typeConverter = TypeConverters .toString )
1302+ linkPredictionCol = Param (Params ._dummy (), "linkPredictionCol" , "link prediction (linear " +
1303+ "predictor) column name" , typeConverter = TypeConverters .toString )
12931304
12941305 @keyword_only
12951306 def __init__ (self , labelCol = "label" , featuresCol = "features" , predictionCol = "prediction" ,
12961307 family = "gaussian" , link = None , fitIntercept = True , maxIter = 25 , tol = 1e-6 ,
1297- regParam = 0.0 , weightCol = None , solver = "irls" ):
1308+ regParam = 0.0 , weightCol = None , solver = "irls" , linkPredictionCol = "" ):
12981309 """
12991310 __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
13001311 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1301- regParam=0.0, weightCol=None, solver="irls")
1312+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="" )
13021313 """
13031314 super (GeneralizedLinearRegression , self ).__init__ ()
13041315 self ._java_obj = self ._new_java_obj (
13051316 "org.apache.spark.ml.regression.GeneralizedLinearRegression" , self .uid )
1306- self ._setDefault (family = "gaussian" , maxIter = 25 , tol = 1e-6 , regParam = 0.0 , solver = "irls" )
1317+ self ._setDefault (family = "gaussian" , maxIter = 25 , tol = 1e-6 , regParam = 0.0 , solver = "irls" ,
1318+ linkPredictionCol = "" )
13071319 kwargs = self .__init__ ._input_kwargs
13081320 self .setParams (** kwargs )
13091321
13101322 @keyword_only
13111323 @since ("2.0.0" )
13121324 def setParams (self , labelCol = "label" , featuresCol = "features" , predictionCol = "prediction" ,
13131325 family = "gaussian" , link = None , fitIntercept = True , maxIter = 25 , tol = 1e-6 ,
1314- regParam = 0.0 , weightCol = None , solver = "irls" ):
1326+ regParam = 0.0 , weightCol = None , solver = "irls" , linkPredictionCol = "" ):
13151327 """
13161328 setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
13171329 family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
1318- regParam=0.0, weightCol=None, solver="irls")
1330+ regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="" )
13191331 Sets params for generalized linear regression.
13201332 """
13211333 kwargs = self .setParams ._input_kwargs
@@ -1338,6 +1350,20 @@ def getFamily(self):
13381350 """
13391351 return self .getOrDefault (self .family )
13401352
1353+ @since ("2.0.0" )
1354+ def setLinkPredictionCol (self , value ):
1355+ """
1356+ Sets the value of :py:attr:`linkPredictionCol`.
1357+ """
1358+ return self ._set (linkPredictionCol = value )
1359+
1360+ @since ("2.0.0" )
1361+ def getLinkPredictionCol (self ):
1362+ """
1363+ Gets the value of linkPredictionCol or its default value.
1364+ """
1365+ return self .getOrDefault (self .linkPredictionCol )
1366+
13411367 @since ("2.0.0" )
13421368 def setLink (self , value ):
13431369 """
0 commit comments