diff --git a/R-package/R/callbacks.R b/R-package/R/callbacks.R index 7c0df2f1dabc..c7bc3539b9db 100644 --- a/R-package/R/callbacks.R +++ b/R-package/R/callbacks.R @@ -498,13 +498,12 @@ cb.cv.predict <- function(save_models = FALSE) { rep(NA_real_, N) } - ntreelimit <- NVL(env$basket$best_ntreelimit, - env$end_iteration * env$num_parallel_tree) + iterationrange <- NVL(env$basket$best_iteration, env$end_iteration) if (NVL(env$params[['booster']], '') == 'gblinear') { - ntreelimit <- 0 # must be 0 for gblinear + iterationrange <- c(0, 0) # must be 0 for gblinear } for (fd in env$bst_folds) { - pr <- predict(fd$bst, fd$watchlist[[2]], ntreelimit = ntreelimit, reshape = TRUE) + pr <- predict(fd$bst, fd$watchlist[[2]], iterationrange = iterationrange, reshape = TRUE) if (is.matrix(pred)) { pred[fd$index, ] <- pr } else { diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 919a442da1f1..fdbbfee85c08 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -345,21 +345,25 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA ntreelimit <- 0 if (ntreelimit != 0 && is.null(iterationrange)) { ## only ntreelimit, initialize iteration range - iterationrange = list(begin = 0, end = 0) + iterationrange = c(0, 0) } else if (ntreelimit == 0 && !is.null(iterationrange)) { ## only iteration range, do nothing } else if (ntreelimit != 0 && !is.null(iterationrange)) { ## both are specified, let libgxgboost throw an error } else { ## no limit is supplied, use best - iterationrange = list(begin = 0, end = NVL(object$best_iteration, 0)) + if (is.null(object$best_iteration)) { + iterationrange = c(0, 0) + } else { + iterationrange = c(0, object$best_iteration + 1) + } } args <- list( training = training, strict_shape = FALSE, - iteration_begin = iterationrange$begin, - iteration_end = iterationrange$end, + iteration_begin = iterationrange[0], + iteration_end = iterationrange[1], ntree_limit = ntreelimit, type = 0 )