Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions python/pyspark/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,10 +1245,14 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
predictor (link function) and a description of the error distribution (family). It supports
"gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
is listed below. The first link function of each family is the default one.
- "gaussian" -> "identity", "log", "inverse"
- "binomial" -> "logit", "probit", "cloglog"
- "poisson" -> "log", "identity", "sqrt"
- "gamma" -> "inverse", "identity", "log"

* "gaussian" -> "identity", "log", "inverse"

* "binomial" -> "logit", "probit", "cloglog"

* "poisson" -> "log", "identity", "sqrt"

* "gamma" -> "inverse", "identity", "log"

.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_

Expand All @@ -1258,9 +1262,12 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
... (1.0, Vectors.dense(1.0, 2.0)),
... (2.0, Vectors.dense(0.0, 0.0)),
... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity", linkPredictionCol="p")
>>> model = glr.fit(df)
>>> abs(model.transform(df).head().prediction - 1.5) < 0.001
>>> transformed = model.transform(df)
>>> abs(transformed.head().prediction - 1.5) < 0.001
True
>>> abs(transformed.head().p - 1.5) < 0.001
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
Expand Down Expand Up @@ -1290,32 +1297,35 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
"relationship between the linear predictor and the mean of the distribution " +
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
"and sqrt.", typeConverter=TypeConverters.toString)
linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " +
"predictor) column name", typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls"):
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
"""
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls")
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
"""
super(GeneralizedLinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls",
linkPredictionCol="")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls"):
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=""):
"""
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls")
regParam=0.0, weightCol=None, solver="irls", linkPredictionCol="")
Sets params for generalized linear regression.
"""
kwargs = self.setParams._input_kwargs
Expand All @@ -1338,6 +1348,20 @@ def getFamily(self):
"""
return self.getOrDefault(self.family)

@since("2.0.0")
def setLinkPredictionCol(self, value):
"""
Sets the value of :py:attr:`linkPredictionCol`.
"""
return self._set(linkPredictionCol=value)

@since("2.0.0")
def getLinkPredictionCol(self):
"""
Gets the value of linkPredictionCol or its default value.
"""
return self.getOrDefault(self.linkPredictionCol)

@since("2.0.0")
def setLink(self, value):
"""
Expand Down