Skip to content

Commit

Permalink
Callback, vec.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 7, 2021
1 parent d06786e commit baf9187
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
7 changes: 3 additions & 4 deletions R-package/R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,12 @@ cb.cv.predict <- function(save_models = FALSE) {
rep(NA_real_, N)
}

ntreelimit <- NVL(env$basket$best_ntreelimit,
env$end_iteration * env$num_parallel_tree)
iterationrange <- NVL(env$basket$best_iteration, env$end_iteration)
if (NVL(env$params[['booster']], '') == 'gblinear') {
ntreelimit <- 0 # must be 0 for gblinear
iterationrange <- c(0, 0) # must be 0 for gblinear
}
for (fd in env$bst_folds) {
pr <- predict(fd$bst, fd$watchlist[[2]], ntreelimit = ntreelimit, reshape = TRUE)
pr <- predict(fd$bst, fd$watchlist[[2]], iterationrange = iterationrange, reshape = TRUE)
if (is.matrix(pred)) {
pred[fd$index, ] <- pr
} else {
Expand Down
12 changes: 8 additions & 4 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -345,21 +345,25 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
ntreelimit <- 0
if (ntreelimit != 0 && is.null(iterationrange)) {
## only ntreelimit, initialize iteration range
iterationrange = list(begin = 0, end = 0)
iterationrange = c(0, 0)
} else if (ntreelimit == 0 && !is.null(iterationrange)) {
## only iteration range, do nothing
} else if (ntreelimit != 0 && !is.null(iterationrange)) {
## both are specified, let libgxgboost throw an error
} else {
## no limit is supplied, use best
iterationrange = list(begin = 0, end = NVL(object$best_iteration, 0))
if (is.null(object$best_iteration)) {
iterationrange = c(0, 0)
} else {
iterationrange = c(0, object$best_iteration + 1)
}
}

args <- list(
training = training,
strict_shape = FALSE,
iteration_begin = iterationrange$begin,
iteration_end = iterationrange$end,
iteration_begin = iterationrange[0],
iteration_end = iterationrange[1],
ntree_limit = ntreelimit,
type = 0
)
Expand Down

0 comments on commit baf9187

Please sign in to comment.