From fe7e9d397fdd31b14af493842ef10b4527200e8d Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 23 Jan 2022 13:17:28 -0300 Subject: [PATCH 1/5] add index1 option to prediction outputs --- R-package/R/lgb.Booster.R | 32 +++++++++++++++++----------- R-package/man/predict.lgb.Booster.Rd | 6 ++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 311d3f2b910c..7db6a927cc37 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -753,6 +753,10 @@ Booster <- R6::R6Class( #' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{ #' the "Predict Parameters" section of the documentation} for a list of parameters and #' valid values. +#' @param index1 When producing outputs that correspond to some numeration (such as +#' leaf indices), whether to make these outputs have a numeration +#' starting at 1 or at zero. Note that the underlying lightgbm core library uses zero-based +#' numeration, thus `index1=FALSE` will be slightly faster. #' @param ... ignored #' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}. #' For multiclass classification, either a \code{num_class * nrows(data)} vector or @@ -806,6 +810,7 @@ predict.lgb.Booster <- function(object, header = FALSE, reshape = FALSE, params = list(), + index1 = TRUE, ...) { if (!lgb.is.Booster(x = object)) { @@ -821,19 +826,22 @@ predict.lgb.Booster <- function(object, )) } - return( - object$predict( - data = data - , start_iteration = start_iteration - , num_iteration = num_iteration - , rawscore = rawscore - , predleaf = predleaf - , predcontrib = predcontrib - , header = header - , reshape = reshape - , params = params - ) + pred <- object$predict( + data = data + , start_iteration = start_iteration + , num_iteration = num_iteration + , rawscore = rawscore + , predleaf = predleaf + , predcontrib = predcontrib + , header = header + , reshape = reshape + , params = params ) + + if (predleaf && index1) { + pred <- pred + 1L + } + return(pred) } #' @name print.lgb.Booster diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index 8948a4b17d01..7edf02a9a7f9 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -15,6 +15,7 @@ header = FALSE, reshape = FALSE, params = list(), + index1 = TRUE, ... ) } @@ -52,6 +53,11 @@ prediction outputs per case.} the "Predict Parameters" section of the documentation} for a list of parameters and valid values.} +\item{index1}{When producing outputs that correspond to some numeration (such as +leaf indices), whether to make these outputs have a numeration +starting at 1 or at zero. Note that the underlying lightgbm core library uses zero-based +numeration, thus `index1=FALSE` will be slightly faster.} + \item{...}{ignored} } \value{ From a64e7c9749a85025ad8fe4d0197d748ccfb25b86 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 25 Jan 2022 19:38:20 -0300 Subject: [PATCH 2/5] apply base-1 also to iteration numbers --- R-package/R/lgb.Booster.R | 23 +++++++++++++++-------- R-package/man/predict.lgb.Booster.Rd | 21 ++++++++++++--------- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 7db6a927cc37..faa06f98e680 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -733,12 +733,15 @@ Booster <- R6::R6Class( #' @param object Object of class \code{lgb.Booster} #' @param data a \code{matrix} object, a \code{dgCMatrix} object or #' a character representing a path to a text file (CSV, TSV, or LibSVM) -#' @param start_iteration int or None, optional (default=None) +#' @param start_iteration int or `NULL`, optional (default=`NULL`) #' Start index of the iteration to predict. -#' If None or <= 0, starts from the first iteration. -#' @param num_iteration int or None, optional (default=None) +#' If `NULL` or <= 0, starts from the first iteration. +#' +#' If using `index1=FALSE`, it will be assumed that the numeration starts +#' at zero (e.g. passing '2' will mean starting from the 3rd round). +#' @param num_iteration int or `NULL`, optional (default=`NULL`) #' Limit number of iterations in the prediction. -#' If None, if the best iteration exists and start_iteration is None or <= 0, the +#' If `NULL`, if the best iteration exists and start_iteration is `NULL` or <= 0, the #' best iteration is used; otherwise, all iterations from start_iteration are used. #' If <= 0, all iterations from start_iteration are used (no limits). #' @param rawscore whether the prediction should be returned in the for of original untransformed @@ -753,10 +756,10 @@ Booster <- R6::R6Class( #' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{ #' the "Predict Parameters" section of the documentation} for a list of parameters and #' valid values. -#' @param index1 When producing outputs that correspond to some numeration (such as -#' leaf indices), whether to make these outputs have a numeration -#' starting at 1 or at zero. Note that the underlying lightgbm core library uses zero-based -#' numeration, thus `index1=FALSE` will be slightly faster. +#' @param index1 When passing argument `start_iteration` and/or when producing outputs that correspond +#' to some numeration (such as leaf indices), whether to take these inputs as and/or make +#' these outputs have a numeration starting at 1 or at 0. Note that the underlying lightgbm +#' core library uses zero-based numeration, thus `index1=FALSE` will be slightly faster. #' @param ... ignored #' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}. #' For multiclass classification, either a \code{num_class * nrows(data)} vector or @@ -826,6 +829,10 @@ predict.lgb.Booster <- function(object, )) } + if (!is.null(start_iteration) && start_iteration > 0 && index1) { + start_iteration <- start_iteration - 1L + } + pred <- object$predict( data = data , start_iteration = start_iteration diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index 7edf02a9a7f9..5f6242efae93 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -25,13 +25,16 @@ \item{data}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a path to a text file (CSV, TSV, or LibSVM)} -\item{start_iteration}{int or None, optional (default=None) -Start index of the iteration to predict. -If None or <= 0, starts from the first iteration.} +\item{start_iteration}{int or `NULL`, optional (default=`NULL`) + Start index of the iteration to predict. + If `NULL` or <= 0, starts from the first iteration. -\item{num_iteration}{int or None, optional (default=None) + If using `index1=FALSE`, it will be assumed that the numeration starts + at zero (e.g. passing '2' will mean starting from the 3rd round).} + +\item{num_iteration}{int or `NULL`, optional (default=`NULL`) Limit number of iterations in the prediction. -If None, if the best iteration exists and start_iteration is None or <= 0, the +If `NULL`, if the best iteration exists and start_iteration is `NULL` or <= 0, the best iteration is used; otherwise, all iterations from start_iteration are used. If <= 0, all iterations from start_iteration are used (no limits).} @@ -53,10 +56,10 @@ prediction outputs per case.} the "Predict Parameters" section of the documentation} for a list of parameters and valid values.} -\item{index1}{When producing outputs that correspond to some numeration (such as -leaf indices), whether to make these outputs have a numeration -starting at 1 or at zero. Note that the underlying lightgbm core library uses zero-based -numeration, thus `index1=FALSE` will be slightly faster.} +\item{index1}{When passing argument `start_iteration` and/or when producing outputs that correspond +to some numeration (such as leaf indices), whether to take these inputs as and/or make +these outputs have a numeration starting at 1 or at 0. Note that the underlying lightgbm +core library uses zero-based numeration, thus `index1=FALSE` will be slightly faster.} \item{...}{ignored} } From 16f5091557400856ce3cf1c61f443880821b3093 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Thu, 27 Jan 2022 00:46:46 -0300 Subject: [PATCH 3/5] fix linter and tests --- R-package/R/lgb.Booster.R | 2 +- R-package/tests/testthat/test_Predictor.R | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index faa06f98e680..8fe64a1816e2 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -829,7 +829,7 @@ predict.lgb.Booster <- function(object, )) } - if (!is.null(start_iteration) && start_iteration > 0 && index1) { + if (!is.null(start_iteration) && start_iteration > 0L && index1) { start_iteration <- start_iteration - 1L } diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 5a1927e4e512..ba71cd3b6317 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -95,6 +95,7 @@ test_that("start_iteration works correctly", { , start_iteration = start_iter , num_iteration = n_iter , rawscore = TRUE + , index1 = FALSE ) inc_pred_contrib <- bst$predict(test$data , start_iteration = start_iter @@ -108,6 +109,13 @@ test_that("start_iteration works correctly", { expect_equal(pred_contrib2, pred_contrib1) pred_leaf1 <- predict(bst, test$data, predleaf = TRUE) - pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE) + pred_leaf2 <- predict( + bst + , test$data + , start_iteration = 0L + , num_iteration = end_iter + 1L + , predleaf = TRUE + , index1 = FALSE + ) expect_equal(pred_leaf1, pred_leaf2) }) From d37b4d3135ed476a3e34e6f680168a04d8f8f6ba Mon Sep 17 00:00:00 2001 From: David Cortes Date: Thu, 27 Jan 2022 15:40:14 -0300 Subject: [PATCH 4/5] fix failing test --- R-package/tests/testthat/test_Predictor.R | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index ba71cd3b6317..b28b9ab5ceee 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -108,7 +108,12 @@ test_that("start_iteration works correctly", { expect_equal(pred2, pred1) expect_equal(pred_contrib2, pred_contrib1) - pred_leaf1 <- predict(bst, test$data, predleaf = TRUE) + pred_leaf1 <- predict( + bst + , test$data + , predleaf = TRUE + , index1 = FALSE + ) pred_leaf2 <- predict( bst , test$data From 0c089501da7bdebcd5fda6f761bb2beaa8265e02 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Thu, 27 Jan 2022 15:41:12 -0300 Subject: [PATCH 5/5] avoid unnecessary cast --- R-package/R/lgb.Booster.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 8fe64a1816e2..b4a5dd5d83ec 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -846,7 +846,7 @@ predict.lgb.Booster <- function(object, ) if (predleaf && index1) { - pred <- pred + 1L + pred <- pred + 1.0 } return(pred) }