-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-17157][SPARKR]: Add multiclass logistic regression SparkR Wrapper #15365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1264b4c
e264d6d
b341d77
63a3ac2
c9e1000
0b54f46
e2ca496
558dc20
d0452ae
031cf9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,6 +95,13 @@ setClass("ALSModel", representation(jobj = "jobj")) | |
| #' @note KSTest since 2.1.0 | ||
| setClass("KSTest", representation(jobj = "jobj")) | ||
|
|
||
| #' S4 class that represents an LogisticRegressionModel | ||
| #' | ||
| #' @param jobj a Java object reference to the backing Scala LogisticRegressionModel | ||
| #' @export | ||
| #' @note LogisticRegressionModel since 2.1.0 | ||
| setClass("LogisticRegressionModel", representation(jobj = "jobj")) | ||
|
|
||
| #' Saves the MLlib model to the input path | ||
| #' | ||
| #' Saves the MLlib model to the input path. For more information, see the specific | ||
|
|
@@ -105,7 +112,7 @@ setClass("KSTest", representation(jobj = "jobj")) | |
| #' @seealso \link{spark.glm}, \link{glm}, | ||
| #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, | ||
| #' @seealso \link{spark.lda}, \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} | ||
| #' @seealso \link{read.ml} | ||
| #' @seealso \link{spark.logit}, \link{read.ml} | ||
| NULL | ||
|
|
||
| #' Makes predictions from a MLlib model | ||
|
|
@@ -117,7 +124,7 @@ NULL | |
| #' @export | ||
| #' @seealso \link{spark.glm}, \link{glm}, | ||
| #' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, | ||
| #' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg} | ||
| #' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.logit} | ||
|
||
| NULL | ||
|
|
||
| write_internal <- function(object, path, overwrite = FALSE) { | ||
|
|
@@ -647,6 +654,170 @@ setMethod("predict", signature(object = "KMeansModel"), | |
| predict_internal(object, newData) | ||
| }) | ||
|
|
||
| #' Logistic Regression Model | ||
| #' | ||
| #' Fits a logistic regression model against a Spark DataFrame. It supports "binomial": Binary logistic regression | ||
| #' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. | ||
| #' Users can print, make predictions on the produced model and save the model to the input path. | ||
| #' | ||
| #' @param data SparkDataFrame for training | ||
| #' @param formula A symbolic description of the model to be fitted. Currently only a few formula | ||
| #' operators are supported, including '~', '.', ':', '+', and '-'. | ||
| #' @param regParam the regularization parameter. Default is 0.0. | ||
| #' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. | ||
| #' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination | ||
| #' of L1 and L2. Default is 0.0 which is an L2 penalty. | ||
| #' @param maxIter maximum iteration number. | ||
| #' @param tol convergence tolerance of iterations. | ||
| #' @param fitIntercept whether to fit an intercept term. Default is TRUE. | ||
| #' @param family the name of family which is a description of the label distribution to be used in the model. | ||
| #' Supported options: | ||
|
||
| #' \itemize{ | ||
| #' \item{"auto": Automatically select the family based on the number of classes: | ||
| #' If number of classes == 1 || number of classes == 2, set to "binomial". | ||
| #' Else, set to "multinomial".} | ||
| #' \item{"binomial": Binary logistic regression with pivoting.} | ||
| #' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} | ||
| #' } | ||
| #' Default is "auto". | ||
| #' @param standardization whether to standardize the training features before fitting the model. The coefficients | ||
| #' of models will be always returned on the original scale, so it will be transparent for | ||
| #' users. Note that with/without standardization, the models should be always converged | ||
| #' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. | ||
| #' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 | ||
| #' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 | ||
| #' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with | ||
| #' threshold p is equivalent to setting thresholds c(1-p, p). When threshold is set, any user-set | ||
| #' value for thresholds will be cleared. If both threshold and thresholds are set, then they must be | ||
| #' equivalent. In multiclass (or binary) classification to adjust the probability of | ||
| #' predicting each class. Array must have length equal to the number of classes, with values > 0, | ||
| #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p | ||
| #' is the original probability of that class and t is the class's threshold. Note: When thresholds | ||
| #' is set, any user-set value for threshold will be cleared. If both threshold and thresholds are | ||
| #' set, then they must be equivalent. Default is 0.5. | ||
| #' @param weightCol The weight column name. | ||
| #' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions | ||
| #' are large, this param could be adjusted to a larger size. Default is 2. | ||
| #' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". | ||
| #' @param ... additional arguments passed to the method. | ||
| #' @return \code{spark.logit} returns a fitted logistic regression model | ||
| #' @rdname spark.logit | ||
| #' @aliases spark.logit,SparkDataFrame,formula-method | ||
| #' @name spark.logit | ||
| #' @export | ||
| #' @examples | ||
| #' \dontrun{ | ||
| #' sparkR.session() | ||
| #' # binary logistic regression | ||
| #' label <- c(1.0, 1.0, 1.0, 0.0, 0.0) | ||
| #' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) | ||
| #' binary_data <- as.data.frame(cbind(label, feature)) | ||
| #' binary_df <- createDataFrame(binary_data) | ||
| #' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) | ||
| #' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) | ||
| #' | ||
| #' # summary of binary logistic regression | ||
| #' blr_summary <- summary(blr_model) | ||
| #' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) | ||
| #' # save fitted model to input path | ||
| #' path <- "path/to/model" | ||
| #' write.ml(blr_model, path) | ||
| #' | ||
| #' # can also read back the saved model and predict | ||
| #' # Note that summary does not work on loaded model | ||
| #' savedModel <- read.ml(path) | ||
| #' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) | ||
| #' | ||
| #' # multinomial logistic regression | ||
| #' | ||
| #' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) | ||
| #' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) | ||
| #' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) | ||
| #' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) | ||
| #' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) | ||
| #' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) | ||
| #' df <- createDataFrame(data) | ||
| #' | ||
| #' # Note that summary of multinomial logistic regression is not implemented yet | ||
| #' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) | ||
| #' predict1 <- collect(select(predict(model, df), "prediction")) | ||
| #' } | ||
| #' @note spark.logit since 2.1.0 | ||
| setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), | ||
| function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, | ||
| tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE, | ||
| thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, | ||
| probabilityCol = "probability") { | ||
| formula <- paste0(deparse(formula), collapse = "") | ||
|
|
||
| if (is.null(weightCol)) { | ||
| weightCol <- "" | ||
| } | ||
|
|
||
| jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", | ||
| data@sdf, formula, as.numeric(regParam), | ||
| as.numeric(elasticNetParam), as.integer(maxIter), | ||
| as.numeric(tol), as.logical(fitIntercept), | ||
| as.character(family), as.logical(standardization), | ||
| as.array(thresholds), as.character(weightCol), | ||
| as.integer(aggregationDepth), as.character(probabilityCol)) | ||
| new("LogisticRegressionModel", jobj = jobj) | ||
| }) | ||
|
|
||
| # Predicted values based on an LogisticRegressionModel model | ||
|
|
||
| #' @param newData a SparkDataFrame for testing. | ||
| #' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. | ||
| #' @rdname spark.logit | ||
|
||
| #' @aliases predict,LogisticRegressionModel,SparkDataFrame-method | ||
| #' @export | ||
| #' @note predict(LogisticRegressionModel) since 2.1.0 | ||
| setMethod("predict", signature(object = "LogisticRegressionModel"), | ||
| function(object, newData) { | ||
| predict_internal(object, newData) | ||
| }) | ||
|
|
||
| # Get the summary of an LogisticRegressionModel | ||
|
|
||
| #' @param object an LogisticRegressionModel fitted by \code{spark.logit} | ||
| #' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that | ||
| #' Multinomial logistic regression summary is not available now. | ||
| #' @rdname spark.logit | ||
| #' @aliases summary,LogisticRegressionModel-method | ||
| #' @export | ||
| #' @note summary(LogisticRegressionModel) since 2.1.0 | ||
| setMethod("summary", signature(object = "LogisticRegressionModel"), | ||
| function(object) { | ||
| jobj <- object@jobj | ||
| is.loaded <- callJMethod(jobj, "isLoaded") | ||
|
|
||
| if (is.loaded) { | ||
| stop("Loaded model doesn't have training summary.") | ||
| } | ||
|
|
||
| roc <- dataFrame(callJMethod(jobj, "roc")) | ||
|
|
||
| areaUnderROC <- callJMethod(jobj, "areaUnderROC") | ||
|
|
||
| pr <- dataFrame(callJMethod(jobj, "pr")) | ||
|
|
||
| fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) | ||
|
|
||
| precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) | ||
|
|
||
| recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) | ||
|
|
||
| totalIterations <- callJMethod(jobj, "totalIterations") | ||
|
|
||
| objectiveHistory <- callJMethod(jobj, "objectiveHistory") | ||
|
|
||
| list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, | ||
| fMeasureByThreshold = fMeasureByThreshold, | ||
| precisionByThreshold = precisionByThreshold, | ||
| recallByThreshold = recallByThreshold, | ||
| totalIterations = totalIterations, objectiveHistory = objectiveHistory) | ||
| }) | ||
|
|
||
| #' Multilayer Perceptron Classification Model | ||
| #' | ||
| #' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame. | ||
|
|
@@ -882,6 +1053,21 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char | |
| write_internal(object, path, overwrite) | ||
| }) | ||
|
|
||
| # Save fitted LogisticRegressionModel to the input path | ||
|
|
||
| #' @param path The directory where the model is saved | ||
| #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE | ||
| #' which means throw exception if the output path exists. | ||
| #' | ||
| #' @rdname spark.logit | ||
| #' @aliases write.ml,LogisticRegressionModel,character-method | ||
| #' @export | ||
| #' @note write.ml(LogisticRegressionModel, character) since 2.1.0 | ||
| setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), | ||
| function(object, path, overwrite = FALSE) { | ||
| write_internal(object, path, overwrite) | ||
| }) | ||
|
|
||
| # Save fitted MLlib model to the input path | ||
|
|
||
| #' @param path the directory where the model is saved. | ||
|
|
@@ -932,6 +1118,8 @@ read.ml <- function(path) { | |
| new("GaussianMixtureModel", jobj = jobj) | ||
| } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { | ||
| new("ALSModel", jobj = jobj) | ||
| } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { | ||
| new("LogisticRegressionModel", jobj = jobj) | ||
| } else { | ||
| stop("Unsupported model: ", jobj) | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this group of links could be sorted
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will make changes when we agree on the name. Thanks!