[R-package] enable saving Booster with saveRDS() and loading it with readRDS() (fixes #4296) #4685

Merged on Dec 4, 2021 (42 commits)

Commits
4c177a5
idiomatic serialization
david-cortes Oct 14, 2021
ada260b
linter
david-cortes Oct 14, 2021
0d551a0
linter, namespace
david-cortes Oct 14, 2021
088eef3
comments, linter, fix failing test
david-cortes Oct 15, 2021
5e2a922
standardize error messages for null handles
david-cortes Oct 15, 2021
5c1d260
auto-restore handle in more functions
david-cortes Oct 15, 2021
8ed14a6
linter
david-cortes Oct 15, 2021
840de5e
missing declaration
david-cortes Oct 15, 2021
af16b2d
correct wrong signature
david-cortes Oct 15, 2021
9d7e6f8
fix docs
david-cortes Oct 15, 2021
4428a23
Update R-package/R/lgb.train.R
david-cortes Oct 15, 2021
730f2e6
Update R-package/R/lgb.drop_serialized.R
david-cortes Oct 15, 2021
719af93
Update R-package/R/lgb.restore_handle.R
david-cortes Oct 15, 2021
41a75bd
Update R-package/R/lgb.restore_handle.R
david-cortes Oct 15, 2021
9b5de4d
Update R-package/R/lgb.make_serializable.R
david-cortes Oct 15, 2021
1f4aa91
move 'restore_handle' from feature importance to dump method
david-cortes Oct 15, 2021
84af4e7
missing header
david-cortes Oct 15, 2021
25557f7
move arguments order, update docs
david-cortes Oct 15, 2021
ff78dd2
linter
david-cortes Oct 15, 2021
19f3c4a
avoid leaving files in working directory
david-cortes Oct 15, 2021
2f3a334
add test for save_model=NULL
david-cortes Oct 15, 2021
6e7b852
missing comma
david-cortes Oct 15, 2021
617b226
Update R-package/R/lgb.restore_handle.R
david-cortes Oct 16, 2021
8a078f4
Update R-package/src/lightgbm_R.cpp
david-cortes Oct 16, 2021
8e194af
change name of error function
david-cortes Oct 16, 2021
d4c8ef1
update comment
david-cortes Oct 16, 2021
44ca8db
restore old serialization functions but set as deprecated
david-cortes Oct 16, 2021
d6f4c74
Update R-package/R/readRDS.lgb.Booster.R
david-cortes Oct 17, 2021
8d282e4
Update R-package/R/saveRDS.lgb.Booster.R
david-cortes Oct 17, 2021
0817eb0
update docs
david-cortes Oct 17, 2021
f845554
Update R-package/R/readRDS.lgb.Booster.R
david-cortes Oct 26, 2021
51fa088
Update R-package/R/saveRDS.lgb.Booster.R
david-cortes Oct 26, 2021
8522ce7
Update R-package/tests/testthat/test_basic.R
david-cortes Oct 26, 2021
c116270
Update R-package/R/readRDS.lgb.Booster.R
david-cortes Oct 26, 2021
bee5bc1
comments
david-cortes Oct 26, 2021
b0f9f93
fix variable name
david-cortes Oct 26, 2021
2d3a132
restore serialization test for linear models
david-cortes Oct 26, 2021
c534952
Update R-package/R/lightgbm.R
david-cortes Nov 18, 2021
58fd21f
Merge branch 'master' into serial
david-cortes Nov 18, 2021
b1b4e2b
update docs
david-cortes Nov 18, 2021
eb7fd32
fix issues with null terminator
david-cortes Dec 4, 2021
34707ae
solve conflicts
david-cortes Dec 4, 2021
3 changes: 3 additions & 0 deletions R-package/NAMESPACE
Expand Up @@ -21,14 +21,17 @@ export(lgb.Dataset.set.categorical)
export(lgb.Dataset.set.reference)
export(lgb.convert_with_rules)
export(lgb.cv)
export(lgb.drop_serialized)
export(lgb.dump)
export(lgb.get.eval.result)
export(lgb.importance)
export(lgb.interprete)
export(lgb.load)
export(lgb.make_serializable)
export(lgb.model.dt.tree)
export(lgb.plot.importance)
export(lgb.plot.interpretation)
export(lgb.restore_handle)
export(lgb.save)
export(lgb.train)
export(lgb.unloader)
Expand Down
72 changes: 59 additions & 13 deletions R-package/R/lgb.Booster.R
Expand Up @@ -86,9 +86,9 @@ Booster <- R6::R6Class(

} else if (!is.null(model_str)) {

# Do we have a model_str as character?
if (!is.character(model_str)) {
stop("lgb.Booster: Can only use a string as model_str")
# Do we have a model_str as character/raw?
if (!is.raw(model_str) && !is.character(model_str)) {
stop("lgb.Booster: Can only use a character/raw vector as model_str")
}

# Create booster from model
Expand Down Expand Up @@ -196,6 +196,8 @@ Booster <- R6::R6Class(
params <- utils::modifyList(params, additional_params)
params_str <- lgb.params2str(params = params)

self$restore_handle()

.Call(
LGBM_BoosterResetParameter_R
, private$handle
Expand Down Expand Up @@ -289,6 +291,8 @@ Booster <- R6::R6Class(
# Return one iteration behind
rollback_one_iter = function() {

self$restore_handle()

.Call(
LGBM_BoosterRollbackOneIter_R
, private$handle
Expand All @@ -306,6 +310,8 @@ Booster <- R6::R6Class(
# Get current iteration
current_iter = function() {

self$restore_handle()

cur_iter <- 0L
.Call(
LGBM_BoosterGetCurrentIteration_R
Expand All @@ -319,6 +325,8 @@ Booster <- R6::R6Class(
# Get upper bound
upper_bound = function() {

self$restore_handle()

upper_bound <- 0.0
.Call(
LGBM_BoosterGetUpperBoundValue_R
Expand All @@ -332,6 +340,8 @@ Booster <- R6::R6Class(
# Get lower bound
lower_bound = function() {

self$restore_handle()

lower_bound <- 0.0
.Call(
LGBM_BoosterGetLowerBoundValue_R
Expand Down Expand Up @@ -423,6 +433,8 @@ Booster <- R6::R6Class(
# Save model
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

self$restore_handle()

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}
Expand All @@ -440,7 +452,9 @@ Booster <- R6::R6Class(
return(invisible(self))
},

save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

self$restore_handle()

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
Expand All @@ -453,13 +467,19 @@ Booster <- R6::R6Class(
, as.integer(feature_importance_type)
)

if (as_char) {
model_str <- rawToChar(model_str)
}

return(model_str)

},

# Dump model in memory
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

self$restore_handle()

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}
Expand Down Expand Up @@ -487,6 +507,8 @@ Booster <- R6::R6Class(
params = list(),
...) {

self$restore_handle()

additional_params <- list(...)
if (length(additional_params) > 0L) {
warning(paste0(
Expand Down Expand Up @@ -531,17 +553,39 @@ Booster <- R6::R6Class(
return(Predictor$new(modelfile = private$handle))
},

# Used for save
raw = NA,
# Used for serialization
raw = NULL,

# Save model to temporary file for in-memory saving
save = function() {
# Store serialized raw bytes in model object
save_raw = function() {
if (is.null(self$raw)) {
self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
}
return(invisible(NULL))

# Overwrite model in object
self$raw <- self$save_model_to_string(NULL)
},

drop_raw = function() {
self$raw <- NULL
return(invisible(NULL))
},

check_null_handle = function() {
return(lgb.is.null.handle(private$handle))
},

restore_handle = function() {
if (self$check_null_handle()) {
if (is.null(self$raw)) {
.Call(LGBM_NullBoosterHandleError_R)
}
private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
}
return(invisible(NULL))
},

get_handle = function() {
return(private$handle)
}

),
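The lazy-restore logic in the methods above (`save_raw`, `drop_raw`, `check_null_handle`, `restore_handle`) can be illustrated without LightGBM. The following is only a base-R sketch under stated assumptions: a plain environment stands in for the R6 object, an R list stands in for the C++ object behind the handle, and `serialize()`/`unserialize()` stand in for the C-level save/load calls. `new_toy_booster` and its `predict` method are hypothetical names used for illustration, not part of the package.

```r
# Toy stand-in for the Booster's lazy-restore logic (hypothetical, base R only).
new_toy_booster <- function(coefs) {
  self <- new.env()
  self$handle <- list(coefs = coefs)  # stands in for the C++ object behind the handle
  self$raw <- NULL                    # serialized bytes, produced on demand

  # Keep a serialized copy only if one is not already stored.
  self$save_raw <- function() {
    if (is.null(self$raw)) {
      self$raw <- serialize(self$handle, connection = NULL)
    }
    invisible(NULL)
  }

  # Free the serialized copy to save memory.
  self$drop_raw <- function() {
    self$raw <- NULL
    invisible(NULL)
  }

  self$check_null_handle <- function() is.null(self$handle)

  # Rebuild the handle from the raw bytes, or fail if there are none.
  self$restore_handle <- function() {
    if (self$check_null_handle()) {
      if (is.null(self$raw)) {
        stop("handle is NULL and no raw bytes are stored")
      }
      self$handle <- unserialize(self$raw)
    }
    invisible(NULL)
  }

  # Any method that needs the handle restores it first.
  self$predict <- function(x) {
    self$restore_handle()
    sum(self$handle$coefs * x)
  }

  self
}

bst <- new_toy_booster(c(1, 2, 3))
bst$save_raw()
bst$handle <- NULL                  # mimic a handle lost during deserialization
pred <- bst$predict(c(1, 1, 1))     # handle is rebuilt from bst$raw before use
```

Calling `restore_handle()` at the top of every method that touches the handle is what lets a de-serialized object work transparently, at the cost of one `NULL` check per call.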
Expand Down Expand Up @@ -640,6 +684,8 @@ Booster <- R6::R6Class(
stop("data_idx should not be greater than num_dataset")
}

self$restore_handle()

private$get_eval_info()

ret <- list()
Expand Down Expand Up @@ -878,7 +924,7 @@ summary.lgb.Booster <- function(object, ...) {
#' @description Load LightGBM takes in either a file path or model string.
#' If both are provided, Load will default to loading from file
#' @param filename path of model file
#' @param model_str a str containing the model
#' @param model_str a str containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
Expand Down Expand Up @@ -928,8 +974,8 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
}

if (model_str_provided) {
if (!is.character(model_str)) {
stop("lgb.load: model_str should be character")
if (!is.raw(model_str) && !is.character(model_str)) {
stop("lgb.load: model_str should be a character/raw vector")
}
return(invisible(Booster$new(model_str = model_str)))
}
Expand Down
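With this change `model_str` may be either a `character` or a `raw` vector. As a rough base-R illustration of how the two representations relate (the model text below is a made-up placeholder, and `is_valid_model_str` is a hypothetical helper that only mirrors the type check in the diff):

```r
# Placeholder text; a real model string would come from save_model_to_string().
model_text <- "tree\nversion=v3\nnum_class=1\n"

model_raw <- charToRaw(model_text)   # raw vector, one element per byte
round_trip <- rawToChar(model_raw)   # back to a length-1 character vector

# Mirrors the type check in the diff: accept raw or character, reject the rest.
is_valid_model_str <- function(model_str) {
  is.raw(model_str) || is.character(model_str)
}
```

Keeping the bytes as `raw` (what `as_char = FALSE` returns in the diff) skips the `rawToChar()` conversion when the model is only stored for later restoration.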
7 changes: 6 additions & 1 deletion R-package/R/lgb.Predictor.R
Expand Up @@ -39,12 +39,17 @@ Predictor <- R6::R6Class(
)
private$need_free_handle <- TRUE

} else if (methods::is(modelfile, "lgb.Booster.handle")) {
} else if (methods::is(modelfile, "lgb.Booster.handle") || inherits(modelfile, "externalptr")) {

# Check if model file is a booster handle already
handle <- modelfile
private$need_free_handle <- FALSE

} else if (lgb.is.Booster(modelfile)) {

handle <- modelfile$get_handle()
private$need_free_handle <- FALSE

} else {

stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
Expand Down
5 changes: 5 additions & 0 deletions R-package/R/lgb.cv.R
Expand Up @@ -98,6 +98,7 @@ lgb.cv <- function(params = list()
, early_stopping_rounds = NULL
, callbacks = list()
, reset_data = FALSE
, serializable = TRUE
, ...
) {

Expand Down Expand Up @@ -456,6 +457,10 @@ lgb.cv <- function(params = list()
})
}

if (serializable) {
lapply(cv_booster$boosters, function(model) model$booster$save_raw())
}

return(cv_booster)

}
Expand Down
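A sketch of what the new `serializable` argument does for cross-validation, assuming an installed `lightgbm` build that includes this change (the guard below skips the example otherwise). The `cv$boosters[[i]]$booster` layout and the public `raw` field are taken from the diff above; the parameter values are arbitrary.

```r
# Run only if lightgbm is installed and lgb.cv() has the `serializable` argument.
can_run <- requireNamespace("lightgbm", quietly = TRUE) &&
  "serializable" %in% names(formals(lightgbm::lgb.cv))

if (can_run) {
  library(lightgbm)
  data("agaricus.train", package = "lightgbm")
  dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)

  cv <- lgb.cv(
    params = list(objective = "binary", num_threads = 1L)
    , data = dtrain
    , nrounds = 3L
    , nfold = 3L
    , verbose = -1L
    , serializable = TRUE
  )

  # Each fold's booster should now carry its own raw bytes.
  folds_have_raw <- vapply(
    cv$boosters
    , function(fold) is.raw(fold$booster$raw)
    , logical(1L)
  )
}
```

Since every fold keeps its own serialized copy, passing `serializable = FALSE` may be worth considering when the fold boosters are never saved.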
18 changes: 18 additions & 0 deletions R-package/R/lgb.drop_serialized.R
@@ -0,0 +1,18 @@
#' @name lgb.drop_serialized
#' @title Drop serialized raw bytes in a LightGBM model object
#' @description If a LightGBM model object was produced with argument `serializable=TRUE`, the R object will keep
#' a copy of the underlying C++ object as raw bytes, which can be used to reconstruct such object after getting
#' serialized and de-serialized, but at the cost of extra memory usage. If these raw bytes are not needed anymore,
#' they can be dropped through this function in order to save memory. Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was produced with `serializable=TRUE`.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, as invisible).
#' @seealso \link{lgb.restore_handle}, \link{lgb.make_serializable}.
#' @export
lgb.drop_serialized <- function(model) {
if (!lgb.is.Booster(x = model)) {
stop("lgb.drop_serialized: model should be an ", sQuote("lgb.Booster"))
}
model$drop_raw()
return(invisible(model))
}
18 changes: 18 additions & 0 deletions R-package/R/lgb.make_serializable.R
@@ -0,0 +1,18 @@
#' @name lgb.make_serializable
#' @title Make a LightGBM object serializable by keeping raw bytes
#' @description If a LightGBM model object was produced with argument `serializable=FALSE`, the R object will not
#' be serializable (e.g. cannot save and load with \code{saveRDS} and \code{readRDS}) as it will lack the raw bytes
#' needed to reconstruct its underlying C++ object. This function can be used to forcibly produce those serialized
#' raw bytes and make the object serializable. Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was produced with `serializable=FALSE`.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, as invisible).
#' @seealso \link{lgb.restore_handle}, \link{lgb.drop_serialized}.
#' @export
lgb.make_serializable <- function(model) {
if (!lgb.is.Booster(x = model)) {
stop("lgb.make_serializable: model should be an ", sQuote("lgb.Booster"))
}
model$save_raw()
return(invisible(model))
}
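The two helpers `lgb.make_serializable` and `lgb.drop_serialized` are inverses with respect to the stored raw bytes. A minimal usage sketch, assuming an installed `lightgbm` build that exports both functions (the guard skips the example otherwise) and that the `raw` field is public as in this diff:

```r
# Run only if lightgbm is installed and exports the functions added in this PR.
can_run <- requireNamespace("lightgbm", quietly = TRUE) &&
  all(
    c("lgb.make_serializable", "lgb.drop_serialized") %in%
      getNamespaceExports("lightgbm")
  )

if (can_run) {
  library(lightgbm)
  data("agaricus.train", package = "lightgbm")
  dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)

  # Train without keeping the serialized copy.
  model <- lgb.train(
    params = list(objective = "binary", num_threads = 1L)
    , data = dtrain
    , nrounds = 3L
    , verbose = -1L
    , serializable = FALSE
  )
  no_raw_at_first <- is.null(model$raw)

  # Produce the raw bytes on demand (modifies `model` in place).
  lgb.make_serializable(model)
  raw_after_make <- is.raw(model$raw)

  # Drop them again to free memory (also in place).
  lgb.drop_serialized(model)
  no_raw_after_drop <- is.null(model$raw)
}
```

Both functions return the model invisibly, so the calls can also be used inside a pipeline without printing anything.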
36 changes: 36 additions & 0 deletions R-package/R/lgb.restore_handle.R
@@ -0,0 +1,36 @@
#' @name lgb.restore_handle
#' @title Restore the C++ component of a de-serialized LightGBM model
#' @description After a LightGBM model object is de-serialized through functions such as \code{load} or
#' \code{readRDS}, its underlying C++ object will be blank and needs to be restored to be able to use it. Such
#' object is restored automatically when calling functions such as \code{predict}, but this function can be
#' used to forcibly restore it beforehand. Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was de-serialized and whose underlying C++ object and R handle
#' need to be restored.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, invisibly).
#' @seealso \link{lgb.make_serializable}, \link{lgb.drop_serialized}.
#' @examples
#' library(lightgbm)
#' data("agaricus.train")
#' model <- lightgbm(
#' agaricus.train$data
#' , agaricus.train$label
#' , params = list(objective = "binary", nthreads = 1L)
#' , nrounds = 5L
#' , save_name = NULL
#' , verbose = 0)
#' fname <- tempfile(fileext = ".rds")
#' saveRDS(model, fname)
#'
#' model_new <- readRDS(fname)
#' model_new$check_null_handle()
#' lgb.restore_handle(model_new)
#' model_new$check_null_handle()
#' @export
lgb.restore_handle <- function(model) {
if (!lgb.is.Booster(x = model)) {
stop("lgb.restore_handle: model should be an ", sQuote("lgb.Booster"))
}
model$restore_handle()
return(invisible(model))
}
5 changes: 5 additions & 0 deletions R-package/R/lgb.train.R
Expand Up @@ -67,6 +67,7 @@ lgb.train <- function(params = list(),
early_stopping_rounds = NULL,
callbacks = list(),
reset_data = FALSE,
serializable = TRUE,
...) {

# validate inputs early to avoid unnecessary computation
Expand Down Expand Up @@ -395,6 +396,10 @@ lgb.train <- function(params = list(),

}

if (serializable) {
booster$save_raw()
}

return(booster)

}
24 changes: 23 additions & 1 deletion R-package/R/lightgbm.R
Expand Up @@ -52,6 +52,8 @@
#' @param params a list of parameters. See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html}{
#' the "Parameters" section of the documentation} for a list of parameters and valid values.
#' @param verbose verbosity for output, if <= 0, also will disable the print of evaluation during training
#' @param serializable whether to make the resulting objects serializable through functions such as
#' \code{save} or \code{saveRDS} (see section "Model serialization").
#' @section Early Stopping:
#'
#' "early stopping" refers to stopping the training process if the model's performance on a given
Expand All @@ -66,6 +68,21 @@
#' in \code{params}, that metric will be considered the "first" one. If you omit \code{metric},
#' a default metric will be used based on your choice for the parameter \code{obj} (keyword argument)
#' or \code{objective} (passed into \code{params}).
#' @section Model serialization:
#'
#' LightGBM model objects can be serialized and de-serialized through functions such as \code{save}
#' or \code{saveRDS}, but similarly to libraries such as 'xgboost', serialization works a bit differently
#' from typical R objects. To make models serializable in R, a copy of the underlying C++ object
#' is produced as serialized raw bytes and stored in the R model object. When this R object is
#' de-serialized, the underlying C++ model object is reconstructed from these raw bytes, but only
#' once some function that uses it is called, such as \code{predict}. To forcibly
#' reconstruct the C++ object after deserialization (e.g. after calling \code{readRDS} or similar), one
#' can use the function \link{lgb.restore_handle}. For example, when making predictions in parallel or in
#' forked processes, it will be faster to restore the handle beforehand.
#'
#' Producing and keeping these raw bytes however uses extra memory, and if they are not required,
#' it is possible to avoid producing them by passing `serializable=FALSE`. In such cases, these raw
#' bytes can be added to the model on demand through function \link{lgb.make_serializable}.
#' @keywords internal
NULL
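The "Model serialization" section above can be exercised end to end with a `saveRDS()`/`readRDS()` round trip. This is a sketch only, assuming an installed `lightgbm` build that exports `lgb.restore_handle` (the guard skips the example otherwise); parameter values are arbitrary.

```r
# Run only if lightgbm is installed and exports lgb.restore_handle().
can_run <- requireNamespace("lightgbm", quietly = TRUE) &&
  "lgb.restore_handle" %in% getNamespaceExports("lightgbm")

if (can_run) {
  library(lightgbm)
  data("agaricus.train", package = "lightgbm")
  X <- agaricus.train$data
  dtrain <- lgb.Dataset(X, label = agaricus.train$label)

  # `serializable` defaults to TRUE, so the raw bytes are kept in the object.
  model <- lgb.train(
    params = list(objective = "binary", num_threads = 1L)
    , data = dtrain
    , nrounds = 3L
    , verbose = -1L
  )

  fname <- tempfile(fileext = ".rds")
  saveRDS(model, fname)
  model_new <- readRDS(fname)
  unlink(fname)

  # Optional: predict() would restore the handle lazily on first use.
  lgb.restore_handle(model_new)

  same_preds <- isTRUE(all.equal(predict(model, X), predict(model_new, X)))
}
```

Restoring the handle explicitly before forking or dispatching to parallel workers avoids each worker rebuilding the C++ object on its first `predict()` call.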

Expand All @@ -76,6 +93,7 @@ NULL
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of response values. If not NULL, will set to dataset
#' @param save_name File name to use when writing the trained model to disk. Should end in ".model".
#' If passing `NULL`, will not save the trained model to disk.
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#' \itemize{
#' \item{\code{valids}: a list of \code{lgb.Dataset} objects, used for validation}
Expand Down Expand Up @@ -113,6 +131,7 @@ lightgbm <- function(data,
save_name = "lightgbm.model",
init_model = NULL,
callbacks = list(),
serializable = TRUE,
...) {

# validate inputs early to avoid unnecessary computation
Expand All @@ -137,6 +156,7 @@ lightgbm <- function(data,
, "early_stopping_rounds" = early_stopping_rounds
, "init_model" = init_model
, "callbacks" = callbacks
, "serializable" = serializable
)
train_args <- append(train_args, list(...))

Expand All @@ -156,7 +176,9 @@ lightgbm <- function(data,
)

# Store model under a specific name
bst$save_model(filename = save_name)
if (!is.null(save_name)) {
bst$save_model(filename = save_name)
}

return(bst)
}
Expand Down