Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] Add print() and summary() methods for Booster #4686

Merged
merged 18 commits into from
Nov 13, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions R-package/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ S3method(dimnames,lgb.Dataset)
S3method(get_field,lgb.Dataset)
S3method(getinfo,lgb.Dataset)
S3method(predict,lgb.Booster)
S3method(print,lgb.Booster)
S3method(set_field,lgb.Dataset)
S3method(setinfo,lgb.Dataset)
S3method(slice,lgb.Dataset)
S3method(summary,lgb.Booster)
export(get_field)
export(getinfo)
export(lgb.Dataset)
Expand Down
55 changes: 55 additions & 0 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,61 @@ predict.lgb.Booster <- function(object,
)
}

#' @name print.lgb.Booster
#' @title Print method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{summary}).
#' @param x Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `x`, returned as invisible.
#' @export
print.lgb.Booster <- function(x, ...) {
# nolint start
handle <- x$.__enclos_env__$private$handle
handle_is_null <- lgb.is.null.handle(handle)

if (!handle_is_null) {
ntrees <- x$current_iter()
if (ntrees == 1L) {
cat("LightGBM Model (1 tree)\n")
} else {
cat(sprintf("LightGBM Model (%d trees)\n", ntrees))
}
} else {
cat("LightGBM Model\n")
}

if (!handle_is_null) {
if (x$.__enclos_env__$private$num_class == 1L) {
cat(sprintf("Objective: %s\n", x$params$objective))
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
} else {
cat(sprintf("Objective: %s (%d classes)\n"
, x$params$objective
, x$.__enclos_env__$private$num_class))
}
} else {
cat("(Booster handle is invalid)\n")
}

if (!handle_is_null) {
ncols <- .Call(LGBM_BoosterGetNumFeature_R, handle)
cat(sprintf("Fitted to dataset with %d columns\n", ncols))
}
# nolint end

return(invisible(x))
}

#' @name summary.lgb.Booster
#' @title Summary method for LightGBM model
#' @description Show summary information about a LightGBM model object (same as \code{print}).
#' @param object Object of class \code{lgb.Booster}
#' @param ... Not used
#' @return The same input `object`, returned as invisible.
#' @export
summary.lgb.Booster <- function(object, ...) {
print(object)
}

#' @name lgb.load
#' @title Load LightGBM model
#' @description Load LightGBM takes in either a file path or model string.
Expand Down
19 changes: 19 additions & 0 deletions R-package/man/print.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions R-package/man/summary.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,15 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
R_API_END();
}

SEXP LGBM_BoosterGetNumFeature_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out = 0;
CHECK_CALL(LGBM_BoosterGetNumFeature(R_ExternalPtrAddr(handle), &out));
return Rf_ScalarInteger(out);
R_API_END();
}

SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
Expand Down Expand Up @@ -889,6 +898,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetTrainingData_R" , (DL_FUNC) &LGBM_BoosterResetTrainingData_R , 2},
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R", (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R, 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
Expand Down
9 changes: 9 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumClasses_R(
SEXP out
);

/*!
* \brief Get number of features.
* \param handle Booster handle
* \return Total number of features, as R integer
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetNumFeature_R(
SEXP handle
);

/*!
* \brief update the model in one round
* \param handle Booster handle
Expand Down
62 changes: 62 additions & 0 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -1035,3 +1035,65 @@ test_that("boosters with linear models at leaves can be written to RDS and re-lo
preds2 <- predict(bst2, X)
expect_identical(preds, preds2)
})

test_that("Booster's print, show, and summary work correctly", {
check_methods_work <- function(model) {
expect_error(print(model), NA)
expect_error(show(model), NA)
expect_error(summary(model), NA)
model$finalize()
expect_error(print(model), NA)
expect_error(show(model), NA)
expect_error(summary(model), NA)
}

data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
check_methods_work(model)

data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
check_methods_work(model)
david-cortes marked this conversation as resolved.
Show resolved Hide resolved
})

test_that("LGBM_BoosterGetNumFeature_R returns correct outputs", {
data("mtcars")
model <- lgb.train(
params = list(objective = "regression")
, data = lgb.Dataset(
as.matrix(mtcars[, -1L])
, label = mtcars$mpg)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(mtcars) - 1L)

data("iris")
model <- lgb.train(
params = list(objective = "multiclass", num_class = 3L)
, data = lgb.Dataset(
as.matrix(iris[, -5L])
, label = as.numeric(factor(iris$Species)) - 1.0
)
, verbose = 0L
, nrounds = 5L
)
ncols <- .Call(LGBM_BoosterGetNumFeature_R, model$.__enclos_env__$private$handle)
expect_equal(ncols, ncol(iris) - 1L)
})