Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

* `surv_reg()` is now defunct and will error if called. Please use `survival_reg()` instead (#1206).

* Enable parsnip to work with xgboost versions >= 2.0.0.0 (#1227).

# parsnip 1.3.3

Expand Down
139 changes: 110 additions & 29 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ xgb_train <- function(
event_level = c("first", "second"),
...
) {
rlang::check_installed("xgboost")
event_level <- rlang::arg_match(event_level, c("first", "second"))
others <- list(...)

Expand Down Expand Up @@ -340,31 +341,70 @@ xgb_train <- function(

others <- process_others(others, arg_list)

if (utils::packageVersion("xgboost") >= "2.0.0.0") {
if (!is.null(num_class) && num_class > 2) {
arg_list$num_class <- num_class
}

param_names <- names(
formals(
getFromNamespace("xgb.params", ns = "xgboost")
)
)

if (any(param_names %in% names(others))) {
elements <- param_names[param_names %in% names(others)]

for (element in elements) {
arg_list[[element]] <- others[[element]]
others[[element]] <- NULL
}
}

if (is.null(arg_list$objective)) {
if (is.numeric(y)) {
arg_list$objective <- "reg:squarederror"
} else {
if (num_class == 2) {
arg_list$objective <- "binary:logistic"
} else {
arg_list$objective <- "multi:softprob"
}
}
}
}

main_args <- c(
list(
data = quote(x$data),
watchlist = quote(x$watchlist),
params = arg_list,
nrounds = nrounds,
early_stopping_rounds = early_stop
),
others
)
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
main_args$evals <- quote(x$watchlist)
} else {
main_args$watchlist <- quote(x$watchlist)
}

if (is.null(main_args$objective)) {
if (is.numeric(y)) {
main_args$objective <- "reg:squarederror"
} else {
if (num_class == 2) {
main_args$objective <- "binary:logistic"
if (utils::packageVersion("xgboost") < "2.0.0.0") {
if (is.null(main_args$objective)) {
if (is.numeric(y)) {
main_args$objective <- "reg:squarederror"
} else {
main_args$objective <- "multi:softprob"
if (num_class == 2) {
main_args$objective <- "binary:logistic"
} else {
main_args$objective <- "multi:softprob"
}
}
}
}

if (!is.null(num_class) && num_class > 2) {
main_args$num_class <- num_class
if (!is.null(num_class) && num_class > 2) {
main_args$num_class <- num_class
}
}

call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
Expand Down Expand Up @@ -471,6 +511,7 @@ as_xgb_data <- function(
event_level = "first",
...
) {
rlang::check_installed("xgboost")
lvls <- levels(y)
n <- nrow(x)

Expand Down Expand Up @@ -506,21 +547,52 @@ as_xgb_data <- function(
watch_list <- list(validation = val_data)

info_list <- list(label = y[trn_index])
if (!is.null(weights)) {
info_list$weight <- weights[trn_index]
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
if (!is.null(weights)) {
dat <- xgboost::xgb.DMatrix(
data = x[trn_index, , drop = FALSE],
missing = NA,
label = y[trn_index],
weight = weights[trn_index]
)
} else {
dat <- xgboost::xgb.DMatrix(
data = x[trn_index, , drop = FALSE],
missing = NA,
label = y[trn_index]
)
}
} else {
if (!is.null(weights)) {
info_list$weight <- weights[trn_index]
}
dat <- xgboost::xgb.DMatrix(
data = x[trn_index, , drop = FALSE],
missing = NA,
info = info_list
)
}
dat <- xgboost::xgb.DMatrix(
data = x[trn_index, , drop = FALSE],
missing = NA,
info = info_list
)
} else {
info_list <- list(label = y)
if (!is.null(weights)) {
info_list$weight <- weights
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
if (!is.null(weights)) {
dat <- xgboost::xgb.DMatrix(
x,
missing = NA,
label = y,
weight = weights
)
} else {
dat <- xgboost::xgb.DMatrix(x, missing = NA, label = y)
}
watch_list <- list(training = dat)
} else {
info_list <- list(label = y)
if (!is.null(weights)) {
info_list$weight <- weights
}
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
watch_list <- list(training = dat)
}
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
watch_list <- list(training = dat)
}
} else {
dat <- xgboost::setinfo(x, "label", y)
Expand Down Expand Up @@ -579,12 +651,21 @@ multi_predict._xgb.Booster <-
}

xgb_by_tree <- function(tree, object, new_data, type, ...) {
pred <- xgb_predict(
object$fit,
new_data = new_data,
iterationrange = c(1, tree + 1),
ntreelimit = NULL
)
rlang::check_installed("xgboost")
if (utils::packageVersion("xgboost") >= "2.0.0.0") {
pred <- xgb_predict(
object$fit,
new_data = new_data,
iterationrange = c(1, tree + 1)
)
} else {
pred <- xgb_predict(
object$fit,
new_data = new_data,
iterationrange = c(1, tree + 1),
ntreelimit = NULL
)
}

# switch based on prediction type
if (object$spec$mode == "regression") {
Expand Down
Loading