3 changes: 2 additions & 1 deletion R/pkg/NAMESPACE
@@ -43,7 +43,8 @@ exportMethods("glm",
"spark.isoreg",
"spark.gaussianMixture",
"spark.als",
"spark.kstest")
"spark.kstest",
"spark.glmnet")

# Job group lifecycle management methods
export("setJobGroup",
5 changes: 5 additions & 0 deletions R/pkg/R/generics.R
@@ -1304,6 +1304,11 @@ setGeneric("year", function(x) { standardGeneric("year") })
#' @export
setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") })

#' @rdname spark.glmnet
#' @export
setGeneric("spark.glmnet", function(data, formula, ...) { standardGeneric("spark.glmnet") })


#' @param x,y For \code{glm}: logical values indicating whether the response vector
#' and model matrix used in the fitting process should be returned as
#' components of the returned value.
75 changes: 73 additions & 2 deletions R/pkg/R/mllib.R
@@ -32,6 +32,13 @@
#' @note GeneralizedLinearRegressionModel since 2.0.0
setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))

#' S4 class that represents a MultinomialLogisticRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala MultinomialLogisticRegressionWrapper
#' @export
#' @note MultinomialLogisticRegressionModel since 2.1.0
setClass("MultinomialLogisticRegressionModel", representation(jobj = "jobj"))

#' S4 class that represents a NaiveBayesModel
#'
#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
@@ -102,7 +109,7 @@ setClass("KSTest", representation(jobj = "jobj"))
#' @rdname write.ml
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.glmnet},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
#' @seealso \link{read.ml}
@@ -115,7 +122,7 @@ NULL
#' @rdname predict
#' @name predict
#' @export
#' @seealso \link{spark.glm}, \link{glm},
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.glmnet},
#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}
NULL
@@ -320,6 +327,54 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
predict_internal(object, newData)
})

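# A suggested roxygen header for the new method follows; the parameter descriptions are
# inferred from the signature below and the corresponding Spark ML LogisticRegression
# parameters, and the example is illustrative only.
#' Multinomial Logistic Regression with Elastic-Net Regularization
#'
#' Fits a multinomial logistic regression model against a SparkDataFrame. Users can call
#' \code{summary} to print a summary of the fitted model, \code{predict} to make
#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few
#' formula operators are supported, including '~', '.', ':', '+', and '-'.
#' @param regParam the regularization parameter.
#' @param elasticNetParam the ElasticNet mixing parameter, in [0, 1]; 0 is an L2 penalty,
#' 1 is an L1 penalty.
#' @param tol convergence tolerance of iterations.
#' @param maxIter maximum iteration number.
#' @param fitIntercept whether to fit an intercept term.
#' @param standardization whether to standardize the training features before fitting.
#' @param thresholds thresholds in multi-class classification to adjust the probability
#' of predicting each class.
#' @param weightCol weight column name.
#' @return \code{spark.glmnet} returns a fitted MultinomialLogisticRegressionModel.
#' @rdname spark.glmnet
#' @aliases spark.glmnet,SparkDataFrame,formula-method
#' @export
#' @examples
#' \dontrun{
#' df <- createDataFrame(iris)
#' model <- spark.glmnet(df, Species ~ ., regParam = 0.1, elasticNetParam = 0.5)
#' summary(model)
#' predictions <- predict(model, df)
#' }
#' @note spark.glmnet since 2.1.0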
setMethod("spark.glmnet", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, tol = 1e-6, maxIter = 100,
fitIntercept = TRUE, standardization = TRUE, thresholds = NULL, weightCol = NULL) {

formula <- paste0(deparse(formula), collapse = "")
if (is.null(weightCol)) {
weightCol <- ""
}
# as.array(NULL) fails, so map a NULL thresholds argument to an empty numeric vector
# (assumed to be treated as "no thresholds set" by the Scala wrapper)
if (is.null(thresholds)) {
thresholds <- numeric(0)
}

jobj <- callJStatic("org.apache.spark.ml.r.MultinomialLogisticRegressionWrapper",
"fit", formula, data@sdf, as.numeric(regParam), as.numeric(elasticNetParam),
tol, as.integer(maxIter), as.logical(fitIntercept),
as.logical(standardization), as.array(thresholds), as.character(weightCol))
new("MultinomialLogisticRegressionModel", jobj = jobj)
})

# Predicted values based on a MultinomialLogisticRegression model

#' @param object a fitted MultinomialLogisticRegressionModel
#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted values
#' @rdname spark.glmnet
#' @aliases predict,MultinomialLogisticRegressionModel,SparkDataFrame-method
#' @export
#' @note predict(MultinomialLogisticRegressionModel) since 2.1.0
setMethod("predict", signature(object = "MultinomialLogisticRegressionModel"),
function(object, newData) {
predict_internal(object, newData)
})

# Get the summary of a MultinomialLogisticRegression model

#' @return \code{summary} returns a list containing the model's \code{coefficients}, \code{intercepts}, and \code{numClasses}.
#' @rdname spark.glmnet
#' @aliases summary,MultinomialLogisticRegressionModel-method
#' @export
#' @note summary(MultinomialLogisticRegressionModel) since 2.1.0
setMethod("summary", signature(object = "MultinomialLogisticRegressionModel"),
function(object) {
jobj <- object@jobj
coefficients <- callJMethod(jobj, "coefficients")
intercepts <- callJMethod(jobj, "intercepts")
numClasses <- callJMethod(jobj, "numClasses")
k <- callJMethod(jobj, "numFeatures")
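# reshape the flat coefficient vector into a numFeatures x numClasses matrix
# (assuming the wrapper returns the coefficients grouped feature by feature,
# with one entry per class within each group)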
coefficients <- t(matrix(coefficients, ncol = k))
list(coefficients = coefficients, intercepts = intercepts, numClasses = numClasses)
})

# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(),
# similarly to R package e1071's predict.

@@ -826,6 +881,20 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat
write_internal(object, path, overwrite)
})

# Saves the multinomial logistic regression model to the input path.

#' @param path the directory where the model is saved.
#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
#' which means throw an exception if the output path exists.
#'
#' @rdname spark.glmnet
#' @export
#' @note write.ml(MultinomialLogisticRegressionModel, character) since 2.1.0
setMethod("write.ml", signature(object = "MultinomialLogisticRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

# Save fitted MLlib model to the input path

#' @param path the directory where the model is saved.
@@ -922,6 +991,8 @@ read.ml <- function(path) {
new("GaussianMixtureModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
new("ALSModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.MultinomialLogisticRegressionWrapper")) {
new("MultinomialLogisticRegressionModel", jobj = jobj)
} else {
stop("Unsupported model: ", jobj)
}