Skip to content

Commit 55aa4da

Browse files
actuaryzhang authored and Felix Cheung committed
[SPARK-21622][ML][SPARKR] Support offset in SparkR GLM
## What changes were proposed in this pull request?

Support offset in SparkR GLM (apache#16699).

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes apache#18831 from actuaryzhang/sparkROffset.
1 parent 74b4784 commit 55aa4da

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

R/pkg/R/mllib_regression.R

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
7676
#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc".
7777
#' The default value is "frequencyDesc". When the ordering is set to
7878
#' "alphabetDesc", this drops the same category as R when encoding strings.
79+
#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets
80+
#' as 0.0. The feature specified as offset has a constant coefficient of 1.0.
7981
#' @param ... additional arguments passed to the method.
8082
#' @aliases spark.glm,SparkDataFrame,formula-method
8183
#' @return \code{spark.glm} returns a fitted generalized linear model.
@@ -127,7 +129,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
127129
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL,
128130
regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power,
129131
stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
130-
"alphabetDesc", "alphabetAsc")) {
132+
"alphabetDesc", "alphabetAsc"),
133+
offsetCol = NULL) {
131134

132135
stringIndexerOrderType <- match.arg(stringIndexerOrderType)
133136
if (is.character(family)) {
@@ -159,12 +162,19 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
159162
weightCol <- as.character(weightCol)
160163
}
161164

165+
if (!is.null(offsetCol)) {
166+
offsetCol <- as.character(offsetCol)
167+
if (nchar(offsetCol) == 0) {
168+
offsetCol <- NULL
169+
}
170+
}
171+
162172
# For known families, Gamma is upper-cased
163173
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
164174
"fit", formula, data@sdf, tolower(family$family), family$link,
165175
tol, as.integer(maxIter), weightCol, regParam,
166176
as.double(var.power), as.double(link.power),
167-
stringIndexerOrderType)
177+
stringIndexerOrderType, offsetCol)
168178
new("GeneralizedLinearRegressionModel", jobj = jobj)
169179
})
170180

@@ -192,6 +202,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
192202
#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc".
193203
#' The default value is "frequencyDesc". When the ordering is set to
194204
#' "alphabetDesc", this drops the same category as R when encoding strings.
205+
#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets
206+
#' as 0.0. The feature specified as offset has a constant coefficient of 1.0.
195207
#' @return \code{glm} returns a fitted generalized linear model.
196208
#' @rdname glm
197209
#' @export
@@ -209,10 +221,12 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat
209221
function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL,
210222
var.power = 0.0, link.power = 1.0 - var.power,
211223
stringIndexerOrderType = c("frequencyDesc", "frequencyAsc",
212-
"alphabetDesc", "alphabetAsc")) {
224+
"alphabetDesc", "alphabetAsc"),
225+
offsetCol = NULL) {
213226
spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol,
214227
var.power = var.power, link.power = link.power,
215-
stringIndexerOrderType = stringIndexerOrderType)
228+
stringIndexerOrderType = stringIndexerOrderType,
229+
offsetCol = offsetCol)
216230
})
217231

218232
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary().

R/pkg/tests/fulltests/test_mllib_regression.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,14 @@ test_that("spark.glm summary", {
173173
expect_equal(stats$df.residual, rStats$df.residual)
174174
expect_equal(stats$aic, rStats$aic)
175175

176+
# Test spark.glm works with offset
177+
training <- suppressWarnings(createDataFrame(iris))
178+
stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
179+
family = poisson(), offsetCol = "Petal_Length"))
180+
rStats <- suppressWarnings(summary(glm(Sepal.Width ~ Sepal.Length + Species,
181+
data = iris, family = poisson(), offset = iris$Petal.Length)))
182+
expect_true(all(abs(rStats$coefficients - stats$coefficients) < 1e-3))
183+
176184
# Test summary works on base GLM models
177185
baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
178186
baseSummary <- summary(baseModel)

mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ private[r] object GeneralizedLinearRegressionWrapper
7777
regParam: Double,
7878
variancePower: Double,
7979
linkPower: Double,
80-
stringIndexerOrderType: String): GeneralizedLinearRegressionWrapper = {
80+
stringIndexerOrderType: String,
81+
offsetCol: String): GeneralizedLinearRegressionWrapper = {
8182
// scalastyle:on
8283
val rFormula = new RFormula().setFormula(formula)
8384
.setStringIndexerOrderType(stringIndexerOrderType)
@@ -99,6 +100,7 @@ private[r] object GeneralizedLinearRegressionWrapper
99100
glr.setLink(link)
100101
}
101102
if (weightCol != null) glr.setWeightCol(weightCol)
103+
if (offsetCol != null) glr.setOffsetCol(offsetCol)
102104

103105
val pipeline = new Pipeline()
104106
.setStages(Array(rFormulaModel, glr))

0 commit comments

Comments
 (0)