Skip to content

Commit

Permalink
[R-package] enable saving Booster with saveRDS() and loading it with …
Browse files Browse the repository at this point in the history
…readRDS() (fixes #4296) (#4685)

* idiomatic serialization

* linter

* linter, namespace

* comments, linter, fix failing test

* standardize error messages for null handles

* auto-restore handle in more functions

* linter

* missing declaration

* correct wrong signature

* fix docs

* Update R-package/R/lgb.train.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.drop_serialized.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.restore_handle.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.restore_handle.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/lgb.make_serializable.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* move 'restore_handle' from feature importance to dump method

* missing header

* move arguments order, update docs

* linter

* avoid leaving files in working directory

* add test for save_model=NULL

* missing comma

* Update R-package/R/lgb.restore_handle.R

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update R-package/src/lightgbm_R.cpp

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* change name of error function

* update comment

* restore old serialization functions but set as deprecated

* Update R-package/R/readRDS.lgb.Booster.R

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update R-package/R/saveRDS.lgb.Booster.R

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* update docs

* Update R-package/R/readRDS.lgb.Booster.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/saveRDS.lgb.Booster.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/tests/testthat/test_basic.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/R/readRDS.lgb.Booster.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* comments

* fix variable name

* restore serialization test for linear models

* Update R-package/R/lightgbm.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* update docs

* fix issues with null terminator

Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
3 people authored Dec 4, 2021
1 parent f54e32f commit 12b1527
Show file tree
Hide file tree
Showing 25 changed files with 495 additions and 97 deletions.
3 changes: 3 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ export(lgb.Dataset.set.categorical)
export(lgb.Dataset.set.reference)
export(lgb.convert_with_rules)
export(lgb.cv)
export(lgb.drop_serialized)
export(lgb.dump)
export(lgb.get.eval.result)
export(lgb.importance)
export(lgb.interprete)
export(lgb.load)
export(lgb.make_serializable)
export(lgb.model.dt.tree)
export(lgb.plot.importance)
export(lgb.plot.interpretation)
export(lgb.restore_handle)
export(lgb.save)
export(lgb.train)
export(lgb.unloader)
Expand Down
72 changes: 59 additions & 13 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ Booster <- R6::R6Class(

} else if (!is.null(model_str)) {

# Do we have a model_str as character?
if (!is.character(model_str)) {
stop("lgb.Booster: Can only use a string as model_str")
# Do we have a model_str as character/raw?
if (!is.raw(model_str) && !is.character(model_str)) {
stop("lgb.Booster: Can only use a character/raw vector as model_str")
}

# Create booster from model
Expand Down Expand Up @@ -196,6 +196,8 @@ Booster <- R6::R6Class(
params <- utils::modifyList(params, additional_params)
params_str <- lgb.params2str(params = params)

self$restore_handle()

.Call(
LGBM_BoosterResetParameter_R
, private$handle
Expand Down Expand Up @@ -289,6 +291,8 @@ Booster <- R6::R6Class(
# Return one iteration behind
rollback_one_iter = function() {

self$restore_handle()

.Call(
LGBM_BoosterRollbackOneIter_R
, private$handle
Expand All @@ -306,6 +310,8 @@ Booster <- R6::R6Class(
# Get current iteration
current_iter = function() {

self$restore_handle()

cur_iter <- 0L
.Call(
LGBM_BoosterGetCurrentIteration_R
Expand All @@ -319,6 +325,8 @@ Booster <- R6::R6Class(
# Get upper bound
upper_bound = function() {

self$restore_handle()

upper_bound <- 0.0
.Call(
LGBM_BoosterGetUpperBoundValue_R
Expand All @@ -332,6 +340,8 @@ Booster <- R6::R6Class(
# Get lower bound
lower_bound = function() {

self$restore_handle()

lower_bound <- 0.0
.Call(
LGBM_BoosterGetLowerBoundValue_R
Expand Down Expand Up @@ -423,6 +433,8 @@ Booster <- R6::R6Class(
# Save model
save_model = function(filename, num_iteration = NULL, feature_importance_type = 0L) {

self$restore_handle()

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}
Expand All @@ -440,7 +452,9 @@ Booster <- R6::R6Class(
return(invisible(self))
},

save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L) {
save_model_to_string = function(num_iteration = NULL, feature_importance_type = 0L, as_char = TRUE) {

self$restore_handle()

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
Expand All @@ -453,13 +467,19 @@ Booster <- R6::R6Class(
, as.integer(feature_importance_type)
)

if (as_char) {
model_str <- rawToChar(model_str)
}

return(model_str)

},

# Dump model in memory
dump_model = function(num_iteration = NULL, feature_importance_type = 0L) {

self$restore_handle()

if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}
Expand Down Expand Up @@ -487,6 +507,8 @@ Booster <- R6::R6Class(
params = list(),
...) {

self$restore_handle()

additional_params <- list(...)
if (length(additional_params) > 0L) {
warning(paste0(
Expand Down Expand Up @@ -531,17 +553,39 @@ Booster <- R6::R6Class(
return(Predictor$new(modelfile = private$handle))
},

# Used for save
raw = NA,
# Used for serialization
raw = NULL,

# Save model to temporary file for in-memory saving
save = function() {
# Store serialized raw bytes in model object
save_raw = function() {
if (is.null(self$raw)) {
self$raw <- self$save_model_to_string(NULL, as_char = FALSE)
}
return(invisible(NULL))

# Overwrite model in object
self$raw <- self$save_model_to_string(NULL)
},

drop_raw = function() {
  # Forget the serialized copy of the model kept in the public `raw` field.
  self$raw <- NULL
  # Called for its side effect only: hand back an invisible NULL.
  invisible(NULL)
},

check_null_handle = function() {
  # Ask the package helper whether the private booster handle is a null handle.
  handle_is_null <- lgb.is.null.handle(private$handle)
  return(handle_is_null)
},

restore_handle = function() {
  # A non-null handle needs no work.
  if (!self$check_null_handle()) {
    return(invisible(NULL))
  }

  # The handle is null: it can only be rebuilt from the stored raw bytes.
  # Without them, delegate to the C++ side to report the null-handle problem
  # (presumably raises an R error -- confirm in lightgbm_R.cpp).
  if (is.null(self$raw)) {
    .Call(LGBM_NullBoosterHandleError_R)
  }

  # Re-create the booster from the serialized model and keep the new handle.
  private$handle <- .Call(LGBM_BoosterLoadModelFromString_R, self$raw)
  return(invisible(NULL))
},

get_handle = function() {
  # Expose the private booster handle to callers outside the class.
  booster_handle <- private$handle
  return(booster_handle)
}

),
Expand Down Expand Up @@ -640,6 +684,8 @@ Booster <- R6::R6Class(
stop("data_idx should not be greater than num_dataset")
}

self$restore_handle()

private$get_eval_info()

ret <- list()
Expand Down Expand Up @@ -878,7 +924,7 @@ summary.lgb.Booster <- function(object, ...) {
#' @description Load LightGBM takes in either a file path or model string.
#' If both are provided, Load will default to loading from file
#' @param filename path of model file
#' @param model_str a str containing the model
#' @param model_str a str containing the model (as a `character` or `raw` vector)
#'
#' @return lgb.Booster
#'
Expand Down Expand Up @@ -928,8 +974,8 @@ lgb.load <- function(filename = NULL, model_str = NULL) {
}

if (model_str_provided) {
if (!is.character(model_str)) {
stop("lgb.load: model_str should be character")
if (!is.raw(model_str) && !is.character(model_str)) {
stop("lgb.load: model_str should be a character/raw vector")
}
return(invisible(Booster$new(model_str = model_str)))
}
Expand Down
7 changes: 6 additions & 1 deletion R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,17 @@ Predictor <- R6::R6Class(
)
private$need_free_handle <- TRUE

} else if (methods::is(modelfile, "lgb.Booster.handle")) {
} else if (methods::is(modelfile, "lgb.Booster.handle") || inherits(modelfile, "externalptr")) {

# Check if model file is a booster handle already
handle <- modelfile
private$need_free_handle <- FALSE

} else if (lgb.is.Booster(modelfile)) {

handle <- modelfile$get_handle()
private$need_free_handle <- FALSE

} else {

stop("lgb.Predictor: modelfile must be either a character filename or an lgb.Booster.handle")
Expand Down
5 changes: 5 additions & 0 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ lgb.cv <- function(params = list()
, early_stopping_rounds = NULL
, callbacks = list()
, reset_data = FALSE
, serializable = TRUE
, ...
) {

Expand Down Expand Up @@ -456,6 +457,10 @@ lgb.cv <- function(params = list()
})
}

if (serializable) {
lapply(cv_booster$boosters, function(model) model$booster$save_raw())
}

return(cv_booster)

}
Expand Down
18 changes: 18 additions & 0 deletions R-package/R/lgb.drop_serialized.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#' @name lgb.drop_serialized
#' @title Drop serialized raw bytes in a LightGBM model object
#' @description A LightGBM model object produced with argument `serializable=TRUE` keeps, inside the R object,
#'              a copy of the underlying C++ object as raw bytes. That copy allows the C++ object to be
#'              reconstructed after the model is serialized and de-serialized, at the cost of extra memory usage.
#'              When the raw bytes are no longer needed, this function drops them in order to save memory.
#'              Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was produced with `serializable=TRUE`.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, as invisible).
#' @seealso \link{lgb.restore_handle}, \link{lgb.make_serializable}.
#' @export
lgb.drop_serialized <- function(model) {
  is_booster <- lgb.is.Booster(x = model)
  if (is_booster) {
    # Discard the raw bytes stored in the booster (in-place modification).
    model$drop_raw()
    return(invisible(model))
  }
  stop("lgb.drop_serialized: model should be an ", sQuote("lgb.Booster"))
}
18 changes: 18 additions & 0 deletions R-package/R/lgb.make_serializable.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#' @name lgb.make_serializable
#' @title Make a LightGBM object serializable by keeping raw bytes
#' @description A LightGBM model object produced with argument `serializable=FALSE` lacks the raw bytes needed
#'              to reconstruct its underlying C++ object, so it is not serializable (e.g. it cannot be saved
#'              and loaded with \code{saveRDS} and \code{readRDS}). This function forcibly produces those
#'              serialized raw bytes and so makes the object serializable.
#'              Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was produced with `serializable=FALSE`.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, as invisible).
#' @seealso \link{lgb.restore_handle}, \link{lgb.drop_serialized}.
#' @export
lgb.make_serializable <- function(model) {
  is_booster <- lgb.is.Booster(x = model)
  if (is_booster) {
    # Store the serialized raw bytes inside the booster (in-place modification).
    model$save_raw()
    return(invisible(model))
  }
  stop("lgb.make_serializable: model should be an ", sQuote("lgb.Booster"))
}
36 changes: 36 additions & 0 deletions R-package/R/lgb.restore_handle.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#' @name lgb.restore_handle
#' @title Restore the C++ component of a de-serialized LightGBM model
#' @description After a LightGBM model object is de-serialized through functions such as \code{load} or
#'              \code{readRDS}, its underlying C++ object will be blank and needs to be restored to be able
#'              to use it. Such object is restored automatically when calling functions such as
#'              \code{predict}, but this function can be used to forcibly restore it beforehand.
#'              Note that the object will be modified in-place.
#' @param model \code{lgb.Booster} object which was de-serialized and whose underlying C++ object and R handle
#'              need to be restored.
#'
#' @return \code{lgb.Booster} (the same `model` object that was passed as input, invisibly).
#' @seealso \link{lgb.make_serializable}, \link{lgb.drop_serialized}.
#' @examples
#' library(lightgbm)
#' data("agaricus.train")
#' model <- lightgbm(
#'   agaricus.train$data
#'   , agaricus.train$label
#'   , params = list(objective = "binary", nthreads = 1L)
#'   , nrounds = 5L
#'   , save_name = NULL
#'   , verbose = 0)
#' fname <- tempfile(fileext = ".rds")
#' saveRDS(model, fname)
#'
#' model_new <- readRDS(fname)
#' model_new$check_null_handle()
#' lgb.restore_handle(model_new)
#' model_new$check_null_handle()
#' @export
lgb.restore_handle <- function(model) {
  # Only Booster objects carry a handle that can be restored.
  if (!lgb.is.Booster(x = model)) {
    stop("lgb.restore_handle: model should be an ", sQuote("lgb.Booster"))
  }
  # Delegate to the Booster's own restore method; the model is modified in-place.
  model$restore_handle()
  return(invisible(model))
}
5 changes: 5 additions & 0 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ lgb.train <- function(params = list(),
early_stopping_rounds = NULL,
callbacks = list(),
reset_data = FALSE,
serializable = TRUE,
...) {

# validate inputs early to avoid unnecessary computation
Expand Down Expand Up @@ -395,6 +396,10 @@ lgb.train <- function(params = list(),

}

if (serializable) {
booster$save_raw()
}

return(booster)

}
24 changes: 23 additions & 1 deletion R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#' @param params a list of parameters. See \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html}{
#' the "Parameters" section of the documentation} for a list of parameters and valid values.
#' @param verbose verbosity for output, if <= 0, also will disable the print of evaluation during training
#' @param serializable whether to make the resulting objects serializable through functions such as
#' \code{save} or \code{saveRDS} (see section "Model serialization").
#' @section Early Stopping:
#'
#' "early stopping" refers to stopping the training process if the model's performance on a given
Expand All @@ -66,6 +68,21 @@
#' in \code{params}, that metric will be considered the "first" one. If you omit \code{metric},
#' a default metric will be used based on your choice for the parameter \code{obj} (keyword argument)
#' or \code{objective} (passed into \code{params}).
#' @section Model serialization:
#'
#' LightGBM model objects can be serialized and de-serialized through functions such as \code{save}
#' or \code{saveRDS}, but similarly to libraries such as 'xgboost', serialization works a bit differently
#' from typical R objects. In order to make models serializable in R, a copy of the underlying C++ object
#' as serialized raw bytes is produced and stored in the R model object, and when this R object is
#' de-serialized, the underlying C++ model object gets reconstructed from these raw bytes, but will only
#' do so once some function that uses it is called, such as \code{predict}. In order to forcibly
#' reconstruct the C++ object after deserialization (e.g. after calling \code{readRDS} or similar), one
#' can use the function \link{lgb.restore_handle} (for example, if one makes predictions in parallel or in
#' forked processes, it will be faster to restore the handle beforehand).
#'
#' Producing and keeping these raw bytes however uses extra memory, and if they are not required,
#' it is possible to avoid producing them by passing `serializable=FALSE`. In such cases, these raw
#' bytes can be added to the model on demand through function \link{lgb.make_serializable}.
#' @keywords internal
NULL

Expand All @@ -76,6 +93,7 @@ NULL
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of response values. If not NULL, will set to dataset
#' @param save_name File name to use when writing the trained model to disk. Should end in ".model".
#' If passing `NULL`, will not save the trained model to disk.
#' @param ... Additional arguments passed to \code{\link{lgb.train}}. For example
#' \itemize{
#' \item{\code{valids}: a list of \code{lgb.Dataset} objects, used for validation}
Expand Down Expand Up @@ -113,6 +131,7 @@ lightgbm <- function(data,
save_name = "lightgbm.model",
init_model = NULL,
callbacks = list(),
serializable = TRUE,
...) {

# validate inputs early to avoid unnecessary computation
Expand All @@ -137,6 +156,7 @@ lightgbm <- function(data,
, "early_stopping_rounds" = early_stopping_rounds
, "init_model" = init_model
, "callbacks" = callbacks
, "serializable" = serializable
)
train_args <- append(train_args, list(...))

Expand All @@ -156,7 +176,9 @@ lightgbm <- function(data,
)

# Store model under a specific name
bst$save_model(filename = save_name)
if (!is.null(save_name)) {
bst$save_model(filename = save_name)
}

return(bst)
}
Expand Down
Loading

0 comments on commit 12b1527

Please sign in to comment.