Skip to content

Commit

Permalink
[R-package] Add print() and summary() methods for Booster (#4686)
Browse files Browse the repository at this point in the history
* add print and summary S3 method

* correct wrong signature

* attempt at bypassing linter

* Update R-package/R/lgb.Booster.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update R-package/src/lightgbm_R.h

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Update include/LightGBM/c_api.h

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* add more tests

* linter

* don't pluralize single tree

* remove duplicated function

* update changed function name

* missing declaration

* Update lightgbm_R.h

* Update R-package/tests/testthat/test_lgb.Booster.R

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* accommodate custom objectives in print

* linter

* linter

Co-authored-by: James Lamb <jaylamb20@gmail.com>
  • Loading branch information
david-cortes and jameslamb authored Nov 13, 2021
1 parent 6e6fb14 commit 2f59773
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 0 deletions.
2 changes: 2 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster)
S3method(print,lgb.Booster)
S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset)
S3method(summary,lgb.Booster)
export(get_field)
export(getinfo)
export(lgb.Dataset)
Expand Down
59 changes: 59 additions & 0 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,65 @@ predict.lgb.Booster <- function(object,
)
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
# nolint start
handle <- x$.__enclos_env__$private$handle
handle_is_null <- lgb.is.null.handle(handle)

if (!handle_is_null) {
ntrees <- x$current_iter()
if (ntrees == 1L) {
cat("LightGBM Model (1 tree)\n")
} else {
cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
}
} else {
cat("LightGBM Model\n")
}

if (!handle_is_null) {
obj <- x$params$objective
if (obj == "none") {
obj <- "custom"
}
if (x$.__enclos_env__$private$num_class == 1L) {
cat(sprintf("Objective: %s\n", obj))
} else {
cat(sprintf("Objective: %s (%d classes)\n"
, obj
, x$.__enclos_env__$private$num_class))
}
} else {
cat("(Booster handle is invalid)\n")
}

if (!handle_is_null) {
ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
cat(sprintf("Fitted to dataset with %d columns\n", ncols))
}
# nolint end

return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
print(object)
}

#' @name lgb.load
#' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string.
Expand Down
19 changes: 19 additions & 0 deletions R-package/man/print.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions R-package/man/summary.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,15 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
R_API_END();
}

SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out = 0;
CHECK_CALL(LGBM_BoosterGetNumFeature(R_ExternalPtrAddr(handle), &out));
return Rf_ScalarInteger(out);
R_API_END();
}

SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
Expand Down Expand Up @@ -889,6 +898,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2},
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R", (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R, 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
Expand Down
9 changes: 9 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
SEXP out
);

/*!
* \brief Get number of features.
* \param handle Booster handle
* \return Total number of features, as R integer
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumFeature_R(
SEXP handle
);

/*!
* \brief update the model in one round
* \param handle Booster handle
Expand Down
113 changes: 113 additions & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -1041,3 +1041,116 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo
preds2 <- predict(bst2, X)
expect_identical(preds, preds2)
})

test_that("Booster's print, show, and summary work correctly", {
.have_same_handle <- function(model, other_model) {
expect_equal(
model$.__enclos_env__$private$handle
, other_model$.__enclos_env__$private$handle
)
}

.check_methods_work <- function(model) {

# should work for fitted models
ret <- print(model)
.have_same_handle(ret, model)
ret <- show(model)
expect_null(ret)
ret <- summary(model)
.have_same_handle(ret, model)

# should not fail for finalized models
model$finalize()
ret <- print(model)
.have_same_handle(ret, model)
ret <- show(model)
expect_null(ret)
ret <- summary(model)
.have_same_handle(ret, model)
}

data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
.check_methods_work(model)

data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
.check_methods_work(model)


# with custom objective
.logregobj <- function(preds, dtrain) {
labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels
hess <- preds * (1.0 - preds)
return(list(grad = grad, hess = hess))
}

.evalerror <- function(preds, dtrain) {
labels <- get_field(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
err <- as.numeric(sum(labels != (preds > 0.5))) / length(labels)
return(list(
name = "error"
, value = err
, higher_better = FALSE
))
}

model <- lgb.train(
data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(iris$Species == "virginica")
)
, obj = .logregobj
, eval = .evalerror
, verbose = 0L
, nrounds = 5L
)

.check_methods_work(model)
})

test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", {
data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(mtcars) - 1L)

data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(iris) - 1L)
})

0 comments on commit 2f59773

Please sign in to comment.