Skip to content

Commit

Permalink
[R] Basic implementation for R for JSON serialization.
Browse files Browse the repository at this point in the history
* Change `xgb.save.raw' into full serialization instead of simple model.
* Add `xgb.load.raw' for unserialization.
* Force renew.
  • Loading branch information
trivialfis committed Apr 1, 2020
1 parent 6601a64 commit 5b76aea
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 55 deletions.
4 changes: 4 additions & 0 deletions R-package/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@ set_target_properties(
set(XGBOOST_DEFINITIONS "${XGBOOST_DEFINITIONS};${R_DEFINITIONS}" PARENT_SCOPE)
set(XGBOOST_OBJ_SOURCES $<TARGET_OBJECTS:xgboost-r> PARENT_SCOPE)
set(LINKED_LIBRARIES_PRIVATE ${LINKED_LIBRARIES_PRIVATE} ${LIBR_CORE_LIBRARY} PARENT_SCOPE)

# Link the OpenMP C++ runtime into the xgboost-r object library when the build
# is configured with OpenMP support.
if(USE_OPENMP)
  target_link_libraries(xgboost-r PRIVATE OpenMP::OpenMP_CXX)
endif()
5 changes: 5 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export("xgb.attr<-")
export("xgb.attributes<-")
export("xgb.config<-")
export("xgb.parameters<-")
export(cb.cv.predict)
export(cb.early.stop)
export(cb.evaluation.log)
Expand All @@ -30,6 +31,7 @@ export(xgb.Booster.complete)
export(xgb.DMatrix)
export(xgb.DMatrix.save)
export(xgb.attr)
export(xgb.attributes)
export(xgb.config)
export(xgb.create.features)
Expand All @@ -40,6 +42,8 @@ export(xgb.ggplot.deepness)
export(xgb.ggplot.importance)
export(xgb.importance)
export(xgb.load)
export(xgb.load.raw)
export(xgb.unserialize)
export(xgb.model.dt.tree)
export(xgb.plot.deepness)
export(xgb.plot.importance)
Expand All @@ -48,6 +52,7 @@ export(xgb.plot.shap)
export(xgb.plot.tree)
export(xgb.save)
export(xgb.save.raw)
export(xgb.serialize)
export(xgb.train)
export(xgboost)
import(methods)
Expand Down
38 changes: 27 additions & 11 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,34 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(), modelfile =
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
stop("cachelist must be a list of xgb.DMatrix objects")
}

handle <- .Call(XGBoosterCreate_R, cachelist)
## Load existing model, dispatch for on disk model file and in memory buffer
if (!is.null(modelfile)) {
if (typeof(modelfile) == "character") {
## A filename
handle <- .Call(XGBoosterCreate_R, cachelist)
.Call(XGBoosterLoadModel_R, handle, modelfile[1])
class(handle) <- "xgb.Booster.handle"
if (length(params) > 0) {
xgb.parameters(handle) <- params
}
return(handle)
} else if (typeof(modelfile) == "raw") {
.Call(XGBoosterLoadModelFromRaw_R, handle, modelfile)
## A memory buffer
bst <- xgb.unserialize(modelfile)
xgb.parameters(bst) <- params
return (bst)
} else if (inherits(modelfile, "xgb.Booster")) {
## A booster object
bst <- xgb.Booster.complete(modelfile, saveraw = TRUE)
.Call(XGBoosterLoadModelFromRaw_R, handle, bst$raw)
bst <- xgb.unserialize(bst$raw)
xgb.parameters(bst) <- params
return (bst)
} else {
stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object")
}
}
## Create new model
handle <- .Call(XGBoosterCreate_R, cachelist)
class(handle) <- "xgb.Booster.handle"
if (length(params) > 0) {
xgb.parameters(handle) <- params
Expand Down Expand Up @@ -113,8 +127,9 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
if (is.null.handle(object$handle)) {
object$handle <- xgb.Booster.handle(modelfile = object$raw)
} else {
if (is.null(object$raw) && saveraw)
object$raw <- xgb.save.raw(object$handle)
if (is.null(object$raw) && saveraw) {
object$raw <- xgb.serialize(object$handle)
}
}
return(object)
}
Expand Down Expand Up @@ -399,7 +414,7 @@ predict.xgb.Booster.handle <- function(object, ...) {
#' That would only matter if attributes need to be set many times.
#' Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters,
#' the raw model cache of an \code{xgb.Booster} object would not be automatically updated,
#' and it would be user's responsibility to call \code{xgb.save.raw} to update it.
#' and it would be user's responsibility to call \code{xgb.serialize} to update it.
#'
#' The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
#' but it doesn't delete the other existing attributes.
Expand Down Expand Up @@ -458,7 +473,7 @@ xgb.attr <- function(object, name) {
}
.Call(XGBoosterSetAttr_R, handle, as.character(name[1]), value)
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
object$raw <- xgb.serialize(object$handle)
}
object
}
Expand Down Expand Up @@ -498,7 +513,7 @@ xgb.attributes <- function(object) {
.Call(XGBoosterSetAttr_R, handle, names(a[i]), a[[i]])
}
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
object$raw <- xgb.serialize(object$handle)
}
object
}
Expand Down Expand Up @@ -528,7 +543,8 @@ xgb.config <- function(object) {
`xgb.config<-` <- function(object, value) {
  handle <- xgb.get.handle(object)
  # Apply the JSON configuration to the underlying native booster.
  .Call(XGBoosterLoadJsonConfig_R, handle, value)
  # Drop the cached raw buffer so that xgb.Booster.complete() re-serializes the
  # reconfigured booster instead of keeping the stale snapshot.
  object$raw <- NULL
  object <- xgb.Booster.complete(object)
  object
}

Expand Down Expand Up @@ -568,7 +584,7 @@ xgb.config <- function(object) {
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])
}
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle)
object$raw <- xgb.serialize(object$handle)
}
object
}
Expand Down
14 changes: 14 additions & 0 deletions R-package/R/xgb.load.raw.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#' Load serialised xgboost model from R's raw vector
#'
#' Creates a fresh booster handle and loads into it the model stored in a raw
#' memory buffer, such as the one produced by \code{xgb.save.raw}.
#'
#' @param buffer the buffer returned by xgb.save.raw
#'
#' @return An object of class \code{xgb.Booster.handle}.
#'
#' @export
xgb.load.raw <- function(buffer) {
  # The new booster is created with an empty cache list.
  booster_handle <- .Call(XGBoosterCreate_R, list())
  .Call(XGBoosterLoadModelFromRaw_R, booster_handle, buffer)
  class(booster_handle) <- "xgb.Booster.handle"
  booster_handle
}
16 changes: 8 additions & 8 deletions R-package/R/xgb.save.raw.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
#' Save xgboost model to R's raw vector,
#' user can call xgb.load.raw to load the model back from raw vector
#'
#' Save xgboost model from xgboost or xgb.train
#'
#' @param model the model object.
#'
#' @return A raw vector holding the saved model.
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
#' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#'                eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' raw <- xgb.save.raw(bst)
#' bst <- xgb.load.raw(raw)
#' pred <- predict(bst, test$data)
#'
#' @export
xgb.save.raw <- function(model) {
  # Resolve the native handle once and dump the model a single time.
  handle <- xgb.get.handle(model)
  .Call(XGBoosterModelToRaw_R, handle)
}
11 changes: 11 additions & 0 deletions R-package/R/xgb.serialize.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#' Serialize the booster instance into R's raw vector.
#'
#' Unlike \code{\link{xgb.save.raw}}, which saves only the model but not
#' parameters, this serialization method captures the booster instance itself.
#' The serialization format is not stable across different xgboost versions.
#'
#' @param booster the booster instance
#'
#' @export
xgb.serialize <- function(booster) {
  .Call(XGBoosterSerializeToBuffer_R, xgb.get.handle(booster))
}
12 changes: 12 additions & 0 deletions R-package/R/xgb.unserialize.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#' Load the instance back from \code{\link{xgb.serialize}}
#'
#' A new booster handle is created and populated from the serialized buffer.
#'
#' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}}
#'
#' @return An object of class \code{xgb.Booster.handle}.
#'
#' @export
xgb.unserialize <- function(buffer) {
  restored <- .Call(XGBoosterCreate_R, list())
  .Call(XGBoosterUnserializeFromBuffer_R, restored, buffer)
  class(restored) <- "xgb.Booster.handle"
  restored
}
4 changes: 4 additions & 0 deletions R-package/src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
extern SEXP XGBoosterModelToRaw_R(SEXP);
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
Expand Down Expand Up @@ -53,6 +55,8 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
Expand Down
42 changes: 33 additions & 9 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,6 @@ SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
return R_NilValue;
}

// Loads a model from an R raw vector into the booster referenced by `handle`.
// The vector's data pointer and length are forwarded to
// XGBoosterLoadModelFromBuffer; on normal completion R_NilValue is returned.
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
RAW(raw),
length(raw)));
R_API_END();
return R_NilValue;
}

SEXP XGBoosterModelToRaw_R(SEXP handle) {
SEXP ret;
R_API_BEGIN();
Expand All @@ -362,6 +353,15 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) {
return ret;
}

// Load a model into the booster behind `handle` from the bytes of an R raw
// vector. Returns R_NilValue.
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
  R_API_BEGIN();
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(XGBoosterLoadModelFromBuffer(booster, RAW(raw), length(raw)));
  R_API_END();
  return R_NilValue;
}

SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
const char* ret;
R_API_BEGIN();
Expand All @@ -380,6 +380,30 @@ SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
return R_NilValue;
}

// Serialize the booster behind `handle` and return the bytes as a freshly
// allocated R raw vector.
SEXP XGBoosterSerializeToBuffer_R(SEXP handle) {
  SEXP out;
  R_API_BEGIN();
  bst_ulong n_bytes;
  const char *buffer;
  CHECK_CALL(XGBoosterSerializeToBuffer(R_ExternalPtrAddr(handle), &n_bytes, &buffer));
  // Protect the new vector until just before it is handed back to R.
  out = PROTECT(allocVector(RAWSXP, n_bytes));
  if (n_bytes != 0) {
    memcpy(RAW(out), buffer, n_bytes);
  }
  R_API_END();
  UNPROTECT(1);
  return out;
}

// Restore the booster behind `handle` from a raw vector produced by
// XGBoosterSerializeToBuffer_R. Returns R_NilValue.
SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
  R_API_BEGIN();
  // Wrap the call in CHECK_CALL, like the other C API calls in this file, so
  // that its status code is checked rather than discarded.
  CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle),
                                            RAW(raw),
                                            length(raw)));
  R_API_END();
  return R_NilValue;
}

SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
SEXP out;
R_API_BEGIN();
Expand Down
17 changes: 17 additions & 0 deletions R-package/src/xgboost_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
* \param handle handle
* \return JSON string
*/

XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
/*!
* \brief Load the JSON string returned by XGBoosterSaveJsonConfig_R
Expand All @@ -195,6 +196,22 @@ XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
* \return R_NilValue
*/
XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);

/*!
* \brief Memory snapshot based serialization method. Saves all states
* into buffer.
* \param handle handle to booster
* \return raw byte array
*/
XGB_DLL SEXP XGBoosterSerializeToBuffer_R(SEXP handle);

/*!
* \brief Memory snapshot based serialization method. Loads the buffer returned
* from `XGBoosterSerializeToBuffer'.
* \param handle handle to booster
* \param raw raw byte array holding the serialized booster
* \return R_NilValue
*/
XGB_DLL SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);

/*!
* \brief dump model into a string
* \param handle handle
Expand Down
Loading

0 comments on commit 5b76aea

Please sign in to comment.