Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# parsnip (development version)

* Enable generalized random forest (`grf`) models for classification, regression, and quantile regression modes. (#1288)

* `surv_reg()` is now defunct and will error if called. Please use `survival_reg()` instead (#1206).


# parsnip 1.3.3

* Bug fix in how tunable parameters were configured for brulee neural networks.
Expand Down
8 changes: 6 additions & 2 deletions R/aaa_archive.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# no fmt
# fmt: skip
model_info_table <-
tibble::tribble(
~model, ~mode, ~engine, ~pkg,
Expand All @@ -21,6 +21,7 @@ model_info_table <-
"bag_tree", "classification", "rpart", "baguette",
"bart", "classification", "dbarts", NA,
"boost_tree", "classification", "C5.0", NA,
"boost_tree", "classification", "catboost", "bonsai",
"boost_tree", "classification", "h2o", "agua",
"boost_tree", "classification", "h2o_gbm", "agua",
"boost_tree", "classification", "lightgbm", "bonsai",
Expand Down Expand Up @@ -69,6 +70,7 @@ model_info_table <-
"null_model", "classification", "parsnip", NA,
"pls", "classification", "mixOmics", "plsmod",
"rand_forest", "classification", "aorsf", "bonsai",
"rand_forest", "classification", "grf", NA,
"rand_forest", "classification", "h2o", "agua",
"rand_forest", "classification", "partykit", "bonsai",
"rand_forest", "classification", "randomForest", NA,
Expand All @@ -82,11 +84,13 @@ model_info_table <-
"svm_rbf", "classification", "kernlab", NA,
"svm_rbf", "classification", "liquidSVM", NA,
"linear_reg", "quantile regression", "quantreg", NA,
"rand_forest", "quantile regression", "grf", NA,
"auto_ml", "regression", "h2o", "agua",
"bag_mars", "regression", "earth", "baguette",
"bag_mlp", "regression", "nnet", "baguette",
"bag_tree", "regression", "rpart", "baguette",
"bart", "regression", "dbarts", NA,
"boost_tree", "regression", "catboost", "bonsai",
"boost_tree", "regression", "h2o", "agua",
"boost_tree", "regression", "h2o_gbm", "agua",
"boost_tree", "regression", "lightgbm", "bonsai",
Expand Down Expand Up @@ -130,6 +134,7 @@ model_info_table <-
"poisson_reg", "regression", "stan_glmer", "multilevelmod",
"poisson_reg", "regression", "zeroinfl", "poissonreg",
"rand_forest", "regression", "aorsf", "bonsai",
"rand_forest", "regression", "grf", NA,
"rand_forest", "regression", "h2o", "agua",
"rand_forest", "regression", "partykit", "bonsai",
"rand_forest", "regression", "randomForest", NA,
Expand All @@ -145,4 +150,3 @@ model_info_table <-
"svm_rbf", "regression", "kernlab", NA,
"svm_rbf", "regression", "liquidSVM", NA
)

20 changes: 11 additions & 9 deletions R/augment.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,17 @@
#'
#' # ------------------------------------------------------------------------------
#'
#' # Quantile regression example
#' qr_form <-
#' linear_reg() |>
#' set_engine("quantreg") |>
#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
#' fit(mpg ~ ., data = car_trn)
#'
#' augment(qr_form, car_tst)
#' augment(qr_form, car_tst[, -1])
#' if (rlang::is_installed("quantreg")) {
#' # Quantile regression example
#' qr_form <-
#' linear_reg() |>
#' set_engine("quantreg") |>
#' set_mode("quantile regression", quantile_levels = c(0.25, 0.5, 0.75)) |>
#' fit(mpg ~ ., data = car_trn)
#'
#' augment(qr_form, car_tst)
#' augment(qr_form, car_tst[, -1])
#' }
#'
augment.model_fit <- function(x, new_data, eval_time = NULL, ...) {
new_data <- tibble::new_tibble(new_data)
Expand Down
139 changes: 71 additions & 68 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@
#' @export
#' @export fit.model_spec
fit.model_spec <-
function(object,
formula,
data,
case_weights = NULL,
control = control_parsnip(),
...
function(
object,
formula,
data,
case_weights = NULL,
control = control_parsnip(),
...
) {
if (object$mode == "unknown") {
cli::cli_abort(
Expand All @@ -135,7 +136,6 @@ fit.model_spec <-
}
check_formula(formula)


if (is_sparse_matrix(data)) {
data <- sparsevctrs::coerce_to_sparse_tibble(data, rlang::caller_env(0))
}
Expand All @@ -153,12 +153,14 @@ fit.model_spec <-
eng_vals <- possible_engines(object)
object$engine <- eng_vals[1]
if (control$verbosity > 0) {
cli::cli_warn("Engine set to {.val {object$engine}}.")
cli::cli_warn("Engine set to {.val {object$engine}}.")
}
}

if (all(c("x", "y") %in% names(dots))) {
cli::cli_abort("{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead.")
cli::cli_abort(
"{.fn fit.model_spec} is for the formula methods. Use {.fn fit_xy} instead."
)
}
cl <- match.call(expand.dots = TRUE)
# Create an environment with the evaluated argument objects. This will be
Expand Down Expand Up @@ -186,11 +188,12 @@ fit.model_spec <-
fit_interface <-
check_interface(eval_env$formula, eval_env$data, cl, object)

if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark"))
if (object$engine == "spark" && !inherits(eval_env$data, "tbl_spark")) {
cli::cli_abort(
"spark objects can only be used with the formula interface to {.fn fit}
"spark objects can only be used with the formula interface to {.fn fit}
with a spark data object."
)
)
}

# populate `method` with the details for this model type
object <- add_methods(object, engine = object$engine)
Expand All @@ -208,51 +211,49 @@ fit.model_spec <-
switch(
interfaces,
# homogeneous combinations:
formula_formula =
form_form(
object = object,
control = control,
env = eval_env
),
formula_formula = form_form(
object = object,
control = control,
env = eval_env
),

# heterogenous combinations
formula_matrix =
form_xy(
object = object,
control = control,
env = eval_env,
target = object$method$fit$interface,
...
),
formula_data.frame =
form_xy(
object = object,
control = control,
env = eval_env,
target = object$method$fit$interface,
...
),
formula_matrix = form_xy(
object = object,
control = control,
env = eval_env,
target = object$method$fit$interface,
...
),
formula_data.frame = form_xy(
object = object,
control = control,
env = eval_env,
target = object$method$fit$interface,
...
),

cli::cli_abort("{.val {interfaces}} is unknown.")
)
res$censor_probs <- reverse_km(object, eval_env)
model_classes <- class(res$fit)
class(res) <- c(paste0("_", model_classes[1]), "model_fit")
res
}
}

# ------------------------------------------------------------------------------

#' @rdname fit
#' @export
#' @export fit_xy.model_spec
fit_xy.model_spec <-
function(object,
x,
y,
case_weights = NULL,
control = control_parsnip(),
...
function(
object,
x,
y,
case_weights = NULL,
control = control_parsnip(),
...
) {
if (object$mode == "unknown") {
cli::cli_abort(
Expand Down Expand Up @@ -329,32 +330,32 @@ fit_xy.model_spec <-
switch(
interfaces,
# homogeneous combinations:
matrix_matrix = , data.frame_matrix =
xy_xy(
object = object,
env = eval_env,
control = control,
target = "matrix",
...
),

data.frame_data.frame = , matrix_data.frame =
xy_xy(
object = object,
env = eval_env,
control = control,
target = "data.frame",
...
),
matrix_matrix = ,
data.frame_matrix = xy_xy(
object = object,
env = eval_env,
control = control,
target = "matrix",
...
),

data.frame_data.frame = ,
matrix_data.frame = xy_xy(
object = object,
env = eval_env,
control = control,
target = "data.frame",
...
),

# heterogenous combinations
matrix_formula = , data.frame_formula =
xy_form(
object = object,
env = eval_env,
control = control,
...
),
matrix_formula = ,
data.frame_formula = xy_form(
object = object,
env = eval_env,
control = control,
...
),
cli::cli_abort("{.val {interfaces}} is unknown.")
)
res$censor_probs <- reverse_km(object, eval_env)
Expand All @@ -368,7 +369,9 @@ fit_xy.model_spec <-
eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) {
if (capture) {
if (catch) {
junk <- capture.output(res <- try(eval_tidy(e, env = envir, ...), silent = TRUE))
junk <- capture.output(
res <- try(eval_tidy(e, env = envir, ...), silent = TRUE)
)
} else {
junk <- capture.output(res <- eval_tidy(e, env = envir, ...))
}
Expand All @@ -391,13 +394,13 @@ check_interface <- function(formula, data, cl, model, call = caller_env()) {
# Determine the `fit()` interface
form_interface <- !is.null(formula) & !is.null(data)

if (form_interface)
if (form_interface) {
return("formula")
}
cli::cli_abort("Error when checking the interface.", call = call)
}

check_xy_interface <- function(x, y, cl, model, call = caller_env()) {

sparse_ok <- allow_sparse(model)
sparse_x <- inherits(x, "dgCMatrix")
if (!sparse_ok & sparse_x) {
Expand Down
43 changes: 26 additions & 17 deletions R/rand_forest.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,17 @@
#' @export

rand_forest <-
function(mode = "unknown", engine = "ranger", mtry = NULL, trees = NULL, min_n = NULL) {

function(
mode = "unknown",
engine = "ranger",
mtry = NULL,
trees = NULL,
min_n = NULL
) {
args <- list(
mtry = enquo(mtry),
trees = enquo(trees),
min_n = enquo(min_n)
mtry = enquo(mtry),
trees = enquo(trees),
min_n = enquo(min_n)
)

new_model_spec(
Expand All @@ -60,15 +65,19 @@ rand_forest <-
#' @rdname parsnip_update
#' @export
update.rand_forest <-
function(object,
parameters = NULL,
mtry = NULL, trees = NULL, min_n = NULL,
fresh = FALSE, ...) {

function(
object,
parameters = NULL,
mtry = NULL,
trees = NULL,
min_n = NULL,
fresh = FALSE,
...
) {
args <- list(
mtry = enquo(mtry),
trees = enquo(trees),
min_n = enquo(min_n)
mtry = enquo(mtry),
trees = enquo(trees),
min_n = enquo(min_n)
)

update_spec(
Expand Down Expand Up @@ -109,16 +118,17 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {

# See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy`
# should be character even if it contains a number.
if (any(names(arg_vals) == "feature_subset_strategy") &&
isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))) {
if (
any(names(arg_vals) == "feature_subset_strategy") &&
isTRUE(is.numeric(quo_get_expr(arg_vals$feature_subset_strategy)))
) {
arg_vals$feature_subset_strategy <-
paste(quo_get_expr(arg_vals$feature_subset_strategy))
}
}

# add checks to error trap or change things for this method
if (engine == "ranger") {

if (any(names(arg_vals) == "importance")) {
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) {
cli::cli_abort(
Expand Down Expand Up @@ -170,4 +180,3 @@ check_args.rand_forest <- function(object, call = rlang::caller_env()) {
# move translate checks here?
invisible(object)
}

Loading
Loading