-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-15767][ML][SparkR] Decision Tree wrapper in SparkR #17981
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) | |
| #' @note RandomForestClassificationModel since 2.1.0 | ||
| setClass("RandomForestClassificationModel", representation(jobj = "jobj")) | ||
|
|
||
| #' S4 class that represents a DecisionTreeRegressionModel | ||
| #' | ||
| #' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel | ||
| #' @export | ||
| #' @note DecisionTreeRegressionModel since 2.3.0 | ||
| setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) | ||
|
|
||
| #' S4 class that represents a DecisionTreeClassificationModel | ||
| #' | ||
| #' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel | ||
| #' @export | ||
| #' @note DecisionTreeClassificationModel since 2.3.0 | ||
| setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) | ||
|
|
||
| # Create the summary of a tree ensemble model (eg. Random Forest, GBT) | ||
| summary.treeEnsemble <- function(model) { | ||
| jobj <- model@jobj | ||
|
|
@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) { | |
| invisible(x) | ||
| } | ||
|
|
||
| # Create the summary of a decision tree model | ||
| summary.decisionTree <- function(model) { | ||
| jobj <- model@jobj | ||
| formula <- callJMethod(jobj, "formula") | ||
| numFeatures <- callJMethod(jobj, "numFeatures") | ||
| features <- callJMethod(jobj, "features") | ||
| featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") | ||
| maxDepth <- callJMethod(jobj, "maxDepth") | ||
| list(formula = formula, | ||
| numFeatures = numFeatures, | ||
| features = features, | ||
| featureImportances = featureImportances, | ||
| maxDepth = maxDepth, | ||
| jobj = jobj) | ||
| } | ||
|
|
||
| # Prints the summary of decision tree models | ||
| print.summary.decisionTree <- function(x) { | ||
| jobj <- x$jobj | ||
| cat("Formula: ", x$formula) | ||
| cat("\nNumber of features: ", x$numFeatures) | ||
| cat("\nFeatures: ", unlist(x$features)) | ||
| cat("\nFeature importances: ", x$featureImportances) | ||
| cat("\nMax Depth: ", x$maxDepth) | ||
|
|
||
| summaryStr <- callJMethod(jobj, "summary") | ||
| cat("\n", summaryStr, "\n") | ||
| invisible(x) | ||
| } | ||
|
|
||
| #' Gradient Boosted Tree Model for Regression and Classification | ||
| #' | ||
| #' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a | ||
|
|
@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path | |
| function(object, path, overwrite = FALSE) { | ||
| write_internal(object, path, overwrite) | ||
| }) | ||
|
|
||
| #' Decision Tree Model for Regression and Classification | ||
| #' | ||
| #' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on | ||
| #' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree | ||
| #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to | ||
| #' save/load fitted models. | ||
| #' For more details, see | ||
| #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ | ||
| #' Decision Tree Regression} and | ||
| #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ | ||
| #' Decision Tree Classification} | ||
| #' | ||
| #' @param data a SparkDataFrame for training. | ||
| #' @param formula a symbolic description of the model to be fitted. Currently only a few formula | ||
| #' operators are supported, including '~', ':', '+', and '-'. | ||
| #' @param type type of model, one of "regression" or "classification", to fit | ||
| #' @param maxDepth Maximum depth of the tree (>= 0). | ||
| #' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing | ||
| #' how to split on features at each node. More bins give higher granularity. Must be | ||
| #' >= 2 and >= number of categories in any categorical feature. | ||
| #' @param impurity Criterion used for information gain calculation. | ||
| #' For regression, must be "variance". For classification, must be one of | ||
| #' "entropy" and "gini", default is "gini". | ||
| #' @param seed integer seed for random number generation. | ||
| #' @param minInstancesPerNode Minimum number of instances each child must have after split. | ||
| #' @param minInfoGain Minimum information gain for a split to be considered at a tree node. | ||
| #' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). | ||
| #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. | ||
| #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with | ||
| #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching | ||
| #' can speed up training of deeper trees. Users can set how often should the | ||
| #' cache be checkpointed or disable it by setting checkpointInterval. | ||
| #' @param ... additional arguments passed to the method. | ||
| #' @aliases spark.decisionTree,SparkDataFrame,formula-method | ||
| #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. | ||
| #' @rdname spark.decisionTree | ||
| #' @name spark.decisionTree | ||
| #' @export | ||
| #' @examples | ||
| #' \dontrun{ | ||
| #' # fit a Decision Tree Regression Model | ||
| #' df <- createDataFrame(longley) | ||
| #' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) | ||
| #' | ||
| #' # get the summary of the model | ||
| #' summary(model) | ||
| #' | ||
| #' # make predictions | ||
| #' predictions <- predict(model, df) | ||
| #' | ||
| #' # save and load the model | ||
| #' path <- "path/to/model" | ||
| #' write.ml(model, path) | ||
| #' savedModel <- read.ml(path) | ||
| #' summary(savedModel) | ||
| #' | ||
| #' # fit a Decision Tree Classification Model | ||
| #' t <- as.data.frame(Titanic) | ||
| #' df <- createDataFrame(t) | ||
| #' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification") | ||
| #' } | ||
| #' @note spark.decisionTree since 2.3.0 | ||
| setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"), | ||
| function(data, formula, type = c("regression", "classification"), | ||
| maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL, | ||
| minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, | ||
| maxMemoryInMB = 256, cacheNodeIds = FALSE) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider adding |
||
| type <- match.arg(type) | ||
| formula <- paste(deparse(formula), collapse = "") | ||
| if (!is.null(seed)) { | ||
| seed <- as.character(as.integer(seed)) | ||
| } | ||
| switch(type, | ||
| regression = { | ||
| if (is.null(impurity)) impurity <- "variance" | ||
| impurity <- match.arg(impurity, "variance") | ||
| jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", | ||
| "fit", data@sdf, formula, as.integer(maxDepth), | ||
| as.integer(maxBins), impurity, | ||
| as.integer(minInstancesPerNode), as.numeric(minInfoGain), | ||
| as.integer(checkpointInterval), seed, | ||
| as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) | ||
| new("DecisionTreeRegressionModel", jobj = jobj) | ||
| }, | ||
| classification = { | ||
| if (is.null(impurity)) impurity <- "gini" | ||
| impurity <- match.arg(impurity, c("gini", "entropy")) | ||
| jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", | ||
| "fit", data@sdf, formula, as.integer(maxDepth), | ||
| as.integer(maxBins), impurity, | ||
| as.integer(minInstancesPerNode), as.numeric(minInfoGain), | ||
| as.integer(checkpointInterval), seed, | ||
| as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) | ||
| new("DecisionTreeClassificationModel", jobj = jobj) | ||
| } | ||
| ) | ||
| }) | ||
|
|
||
| # Get the summary of a Decision Tree Regression Model | ||
|
|
||
| #' @return \code{summary} returns summary information of the fitted model, which is a list. | ||
| #' The list of components includes \code{formula} (formula), | ||
| #' \code{numFeatures} (number of features), \code{features} (list of features), | ||
| #' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). | ||
| #' @rdname spark.decisionTree | ||
| #' @aliases summary,DecisionTreeRegressionModel-method | ||
| #' @export | ||
| #' @note summary(DecisionTreeRegressionModel) since 2.3.0 | ||
| setMethod("summary", signature(object = "DecisionTreeRegressionModel"), | ||
| function(object) { | ||
| ans <- summary.decisionTree(object) | ||
| class(ans) <- "summary.DecisionTreeRegressionModel" | ||
| ans | ||
| }) | ||
|
|
||
| # Prints the summary of Decision Tree Regression Model | ||
|
|
||
| #' @param x summary object of Decision Tree regression model or classification model | ||
| #' returned by \code{summary}. | ||
| #' @rdname spark.decisionTree | ||
| #' @export | ||
| #' @note print.summary.DecisionTreeRegressionModel since 2.3.0 | ||
| print.summary.DecisionTreeRegressionModel <- function(x, ...) { | ||
| print.summary.decisionTree(x) | ||
| } | ||
|
|
||
| # Get the summary of a Decision Tree Classification Model | ||
|
|
||
| #' @rdname spark.decisionTree | ||
| #' @aliases summary,DecisionTreeClassificationModel-method | ||
| #' @export | ||
| #' @note summary(DecisionTreeClassificationModel) since 2.3.0 | ||
| setMethod("summary", signature(object = "DecisionTreeClassificationModel"), | ||
| function(object) { | ||
| ans <- summary.decisionTree(object) | ||
| class(ans) <- "summary.DecisionTreeClassificationModel" | ||
| ans | ||
| }) | ||
|
|
||
| # Prints the summary of Decision Tree Classification Model | ||
|
|
||
| #' @rdname spark.decisionTree | ||
| #' @export | ||
| #' @note print.summary.DecisionTreeClassificationModel since 2.3.0 | ||
| print.summary.DecisionTreeClassificationModel <- function(x, ...) { | ||
| print.summary.decisionTree(x) | ||
| } | ||
|
|
||
| # Makes predictions from a Decision Tree Regression model or Classification model | ||
|
|
||
| #' @param newData a SparkDataFrame for testing. | ||
| #' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named | ||
| #' "prediction". | ||
| #' @rdname spark.decisionTree | ||
| #' @aliases predict,DecisionTreeRegressionModel-method | ||
| #' @export | ||
| #' @note predict(DecisionTreeRegressionModel) since 2.3.0 | ||
| setMethod("predict", signature(object = "DecisionTreeRegressionModel"), | ||
| function(object, newData) { | ||
| predict_internal(object, newData) | ||
| }) | ||
|
|
||
| #' @rdname spark.decisionTree | ||
| #' @aliases predict,DecisionTreeClassificationModel-method | ||
| #' @export | ||
| #' @note predict(DecisionTreeClassificationModel) since 2.3.0 | ||
| setMethod("predict", signature(object = "DecisionTreeClassificationModel"), | ||
| function(object, newData) { | ||
| predict_internal(object, newData) | ||
| }) | ||
|
|
||
| # Save the Decision Tree Regression or Classification model to the input path. | ||
|
|
||
| #' @param object A fitted Decision Tree regression model or classification model. | ||
| #' @param path The directory where the model is saved. | ||
| #' @param overwrite Overwrites or not if the output path already exists. Default is FALSE | ||
| #' which means throw exception if the output path exists. | ||
| #' | ||
| #' @aliases write.ml,DecisionTreeRegressionModel,character-method | ||
| #' @rdname spark.decisionTree | ||
| #' @export | ||
| #' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 | ||
| setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), | ||
| function(object, path, overwrite = FALSE) { | ||
| write_internal(object, path, overwrite) | ||
| }) | ||
|
|
||
| #' @aliases write.ml,DecisionTreeClassificationModel,character-method | ||
| #' @rdname spark.decisionTree | ||
| #' @export | ||
| #' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 | ||
| setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), | ||
| function(object, path, overwrite = FALSE) { | ||
| write_internal(object, path, overwrite) | ||
| }) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is kind of confusing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wording can be improved a bit I guess but this matches the Scaladoc...