5 changes: 5 additions & 0 deletions R/pkg/NAMESPACE
@@ -63,6 +63,7 @@ exportMethods("glm",
"spark.als",
"spark.kstest",
"spark.logit",
"spark.decisionTree",
"spark.randomForest",
"spark.gbt",
"spark.bisectingKmeans",
@@ -414,6 +415,8 @@ export("as.DataFrame",
"print.summary.GeneralizedLinearRegressionModel",
"read.ml",
"print.summary.KSTest",
"print.summary.DecisionTreeRegressionModel",
"print.summary.DecisionTreeClassificationModel",
"print.summary.RandomForestRegressionModel",
"print.summary.RandomForestClassificationModel",
"print.summary.GBTRegressionModel",
@@ -452,6 +455,8 @@ S3method(print, structField)
S3method(print, structType)
S3method(print, summary.GeneralizedLinearRegressionModel)
S3method(print, summary.KSTest)
S3method(print, summary.DecisionTreeRegressionModel)
S3method(print, summary.DecisionTreeClassificationModel)
S3method(print, summary.RandomForestRegressionModel)
S3method(print, summary.RandomForestClassificationModel)
S3method(print, summary.GBTRegressionModel)
5 changes: 5 additions & 0 deletions R/pkg/R/generics.R
@@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml
#' @export
setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") })

#' @rdname spark.decisionTree
#' @export
setGeneric("spark.decisionTree",
function(data, formula, ...) { standardGeneric("spark.decisionTree") })

#' @rdname spark.randomForest
#' @export
setGeneric("spark.randomForest",
240 changes: 240 additions & 0 deletions R/pkg/R/mllib_tree.R
@@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj"))
#' @note RandomForestClassificationModel since 2.1.0
setClass("RandomForestClassificationModel", representation(jobj = "jobj"))

#' S4 class that represents a DecisionTreeRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel
#' @export
#' @note DecisionTreeRegressionModel since 2.3.0
setClass("DecisionTreeRegressionModel", representation(jobj = "jobj"))

#' S4 class that represents a DecisionTreeClassificationModel
#'
#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel
#' @export
#' @note DecisionTreeClassificationModel since 2.3.0
setClass("DecisionTreeClassificationModel", representation(jobj = "jobj"))

# Create the summary of a tree ensemble model (e.g. Random Forest, GBT)
summary.treeEnsemble <- function(model) {
  jobj <- model@jobj
@@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) {
  invisible(x)
}

# Create the summary of a decision tree model
summary.decisionTree <- function(model) {
  jobj <- model@jobj
  formula <- callJMethod(jobj, "formula")
  numFeatures <- callJMethod(jobj, "numFeatures")
  features <- callJMethod(jobj, "features")
  featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString")
  maxDepth <- callJMethod(jobj, "maxDepth")
  list(formula = formula,
       numFeatures = numFeatures,
       features = features,
       featureImportances = featureImportances,
       maxDepth = maxDepth,
       jobj = jobj)
}

# Prints the summary of decision tree models
print.summary.decisionTree <- function(x) {
  jobj <- x$jobj
  cat("Formula: ", x$formula)
  cat("\nNumber of features: ", x$numFeatures)
  cat("\nFeatures: ", unlist(x$features))
  cat("\nFeature importances: ", x$featureImportances)
  cat("\nMax Depth: ", x$maxDepth)

  summaryStr <- callJMethod(jobj, "summary")
  cat("\n", summaryStr, "\n")
  invisible(x)
}

#' Gradient Boosted Tree Model for Regression and Classification
#'
#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a
@@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path
          function(object, path, overwrite = FALSE) {
            write_internal(object, path, overwrite)
          })

#' Decision Tree Model for Regression and Classification
#'
#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on
#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree
#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to
#' save/load fitted models.
#' For more details, see
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{
#' Decision Tree Regression} and
#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{
#' Decision Tree Classification}
#'
#' @param data a SparkDataFrame for training.
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#'                operators are supported, including '~', ':', '+', and '-'.
#' @param type type of model to fit, one of "regression" or "classification".
#' @param maxDepth Maximum depth of the tree (>= 0).
#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing
#'                how to split on features at each node. More bins give higher granularity. Must be
#'                >= 2 and >= number of categories in any categorical feature.
#' @param impurity Criterion used for information gain calculation.
#'                 For regression, must be "variance". For classification, must be one of
#'                 "entropy" or "gini"; default is "gini".
#' @param seed integer seed for random number generation.
#' @param minInstancesPerNode Minimum number of instances each child must have after a split.
#' @param minInfoGain Minimum information gain for a split to be considered at a tree node.
#' @param checkpointInterval Param for setting checkpoint interval (>= 1) or disabling checkpoint (-1).
#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation.
#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with
#'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
#'                     can speed up training of deeper trees. Users can set how often the cache
#'                     should be checkpointed, or disable checkpointing, by setting checkpointInterval.
[review comment, Contributor] This is kind of confusing:
    "Users can set how often should the cache be checkpointed"

[review comment, Member] wording can be improved a bit I guess, but this matches the Scaladoc...
#' @param ... additional arguments passed to the method.
#' @aliases spark.decisionTree,SparkDataFrame,formula-method
#' @return \code{spark.decisionTree} returns a fitted Decision Tree model.
#' @rdname spark.decisionTree
#' @name spark.decisionTree
#' @export
#' @examples
#' \dontrun{
#' # fit a Decision Tree Regression Model
#' df <- createDataFrame(longley)
#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16)
#'
#' # get the summary of the model
#' summary(model)
#'
#' # make predictions
#' predictions <- predict(model, df)
#'
#' # save and load the model
#' path <- "path/to/model"
#' write.ml(model, path)
#' savedModel <- read.ml(path)
#' summary(savedModel)
#'
#' # fit a Decision Tree Classification Model
#' t <- as.data.frame(Titanic)
#' df <- createDataFrame(t)
#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification")
#' }
#' @note spark.decisionTree since 2.3.0
setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, type = c("regression", "classification"),
maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL,
minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
maxMemoryInMB = 256, cacheNodeIds = FALSE) {
[review comment, Member] consider adding thresholds parameter - possibly as a follow up PR.
            type <- match.arg(type)
            formula <- paste(deparse(formula), collapse = "")
            if (!is.null(seed)) {
              seed <- as.character(as.integer(seed))
            }
            switch(type,
                   regression = {
                     if (is.null(impurity)) impurity <- "variance"
                     impurity <- match.arg(impurity, "variance")
                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper",
                                         "fit", data@sdf, formula, as.integer(maxDepth),
                                         as.integer(maxBins), impurity,
                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                                         as.integer(checkpointInterval), seed,
                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
                     new("DecisionTreeRegressionModel", jobj = jobj)
                   },
                   classification = {
                     if (is.null(impurity)) impurity <- "gini"
                     impurity <- match.arg(impurity, c("gini", "entropy"))
                     jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper",
                                         "fit", data@sdf, formula, as.integer(maxDepth),
                                         as.integer(maxBins), impurity,
                                         as.integer(minInstancesPerNode), as.numeric(minInfoGain),
                                         as.integer(checkpointInterval), seed,
                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
                     new("DecisionTreeClassificationModel", jobj = jobj)
                   }
            )
          })
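
Not part of the diff: a minimal usage sketch of the method above, assuming an active SparkSession (note that createDataFrame(iris) replaces dots in column names with underscores):

irisDF <- createDataFrame(iris)

# "variance" is the only accepted regression impurity, so it can be left NULL.
regModel <- spark.decisionTree(irisDF, Sepal_Length ~ Petal_Length, "regression")

# For classification the impurity defaults to "gini"; "entropy" is the other option.
clsModel <- spark.decisionTree(irisDF, Species ~ ., "classification", impurity = "entropy")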

# Get the summary of a Decision Tree Regression Model

#' @return \code{summary} returns summary information of the fitted model, which is a list.
#'         The list of components includes \code{formula} (formula),
#'         \code{numFeatures} (number of features), \code{features} (list of features),
#'         \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of the tree).
#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeRegressionModel-method
#' @export
#' @note summary(DecisionTreeRegressionModel) since 2.3.0
setMethod("summary", signature(object = "DecisionTreeRegressionModel"),
function(object) {
ans <- summary.decisionTree(object)
class(ans) <- "summary.DecisionTreeRegressionModel"
ans
})
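
Not part of the diff: since the returned summary is a plain named list, its components can be read directly; a sketch continuing the regression example above:

s <- summary(regModel)
s$numFeatures          # number of features used by the fitted tree
s$features             # list of feature names
s$featureImportances   # importances rendered as a single string via toString
s$maxDepth             # the maxDepth setting of the fitted tree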

# Prints the summary of Decision Tree Regression Model

#' @param x summary object of Decision Tree regression model or classification model
#'          returned by \code{summary}.
#' @rdname spark.decisionTree
#' @export
#' @note print.summary.DecisionTreeRegressionModel since 2.3.0
print.summary.DecisionTreeRegressionModel <- function(x, ...) {
  print.summary.decisionTree(x)
}

# Get the summary of a Decision Tree Classification Model

#' @rdname spark.decisionTree
#' @aliases summary,DecisionTreeClassificationModel-method
#' @export
#' @note summary(DecisionTreeClassificationModel) since 2.3.0
setMethod("summary", signature(object = "DecisionTreeClassificationModel"),
function(object) {
ans <- summary.decisionTree(object)
class(ans) <- "summary.DecisionTreeClassificationModel"
ans
})

# Prints the summary of Decision Tree Classification Model

#' @rdname spark.decisionTree
#' @export
#' @note print.summary.DecisionTreeClassificationModel since 2.3.0
print.summary.DecisionTreeClassificationModel <- function(x, ...) {
  print.summary.decisionTree(x)
}

# Makes predictions from a Decision Tree Regression model or Classification model

#' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
#'         "prediction".
#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeRegressionModel-method
#' @export
#' @note predict(DecisionTreeRegressionModel) since 2.3.0
setMethod("predict", signature(object = "DecisionTreeRegressionModel"),
function(object, newData) {
predict_internal(object, newData)
})

#' @rdname spark.decisionTree
#' @aliases predict,DecisionTreeClassificationModel-method
#' @export
#' @note predict(DecisionTreeClassificationModel) since 2.3.0
setMethod("predict", signature(object = "DecisionTreeClassificationModel"),
function(object, newData) {
predict_internal(object, newData)
})
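
Not part of the diff: both predict methods delegate to the shared predict_internal helper and return a SparkDataFrame; a short sketch using the classification model from the sketch above:

preds <- predict(clsModel, irisDF)
head(select(preds, "Species", "prediction"))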

# Save the Decision Tree Regression or Classification model to the input path.

#' @param object A fitted Decision Tree regression model or classification model.
#' @param path The directory where the model is saved.
#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
#'                  which means an exception is thrown if the output path already exists.
#'
#' @aliases write.ml,DecisionTreeRegressionModel,character-method
#' @rdname spark.decisionTree
#' @export
#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0
setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})

#' @aliases write.ml,DecisionTreeClassificationModel,character-method
#' @rdname spark.decisionTree
#' @export
#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0
setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"),
function(object, path, overwrite = FALSE) {
write_internal(object, path, overwrite)
})
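
Not part of the diff: a save/load round trip using a hypothetical path; read.ml (see mllib_utils.R below) dispatches on the saved wrapper class and returns the matching S4 model:

modelPath <- file.path(tempdir(), "spark-decisionTree-regression")  # hypothetical path
write.ml(regModel, modelPath)
reloaded <- read.ml(modelPath)  # a DecisionTreeRegressionModel again
summary(reloaded)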
14 changes: 10 additions & 4 deletions R/pkg/R/mllib_utils.R
@@ -32,8 +32,9 @@
#' @rdname write.ml
#' @name write.ml
#' @export
-#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
-#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
+#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
+#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.lda}, \link{spark.logit},
#' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
@@ -48,8 +49,9 @@ NULL
#' @rdname predict
#' @name predict
#' @export
-#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
-#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
+#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree},
+#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt},
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg},
#' @seealso \link{spark.kmeans},
#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
@@ -110,6 +112,10 @@ read.ml <- function(path) {
new("RandomForestRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) {
new("RandomForestClassificationModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) {
new("DecisionTreeRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) {
new("DecisionTreeClassificationModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) {
new("GBTRegressionModel", jobj = jobj)
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) {