diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index d74739150aa4..617cff35d1c3 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -252,3 +252,46 @@ test_that("lgb.train() throws an informative error if 'valids' contains lgb.Data ) }, regexp = "each element of valids must have a name") }) + +test_that("lgb.train() works with force_col_wise and force_row_wise", { + set.seed(1234L) + nrounds <- 10L + dtrain <- lgb.Dataset( + train$data + , label = train$label + ) + params <- list( + objective = "binary" + , metric = "binary_error" + , force_col_wise = TRUE + ) + bst_colwise <- lgb.train( + params = params + , data = dtrain + , nrounds = nrounds + ) + + params <- list( + objective = "binary" + , metric = "binary_error" + , force_row_wise = TRUE + ) + bst_row_wise <- lgb.train( + params = params + , data = dtrain + , nrounds = nrounds + ) + + expected_error <- 0.003070782 + expect_equal(bst_colwise$eval_train()[[1L]][["value"]], expected_error) + expect_equal(bst_row_wise$eval_train()[[1L]][["value"]], expected_error) + + # check some basic details of the boosters just to be sure force_col_wise + # and force_row_wise are not causing any weird side effects + for (bst in list(bst_row_wise, bst_colwise)) { + expect_equal(bst$current_iter(), nrounds) + parsed_model <- jsonlite::fromJSON(bst$dump_model()) + expect_equal(parsed_model$objective, "binary sigmoid:1") + expect_false(parsed_model$average_output) + } +}) diff --git a/R-package/tests/testthat/test_learning_to_rank.R b/R-package/tests/testthat/test_learning_to_rank.R index 049ba53c78f6..65768a9ae178 100644 --- a/R-package/tests/testthat/test_learning_to_rank.R +++ b/R-package/tests/testthat/test_learning_to_rank.R @@ -47,8 +47,8 @@ test_that("learning-to-rank with lgb.train() works as expected", { } expect_identical(sapply(eval_results, function(x) {x$name}), eval_names) expect_equal(eval_results[[1L]][["value"]], 0.825) - 
expect_true(abs(eval_results[[2L]][["value"]] - 0.795986) < TOLERANCE) - expect_true(abs(eval_results[[3L]][["value"]] - 0.7734639) < TOLERANCE) + expect_true(abs(eval_results[[2L]][["value"]] - 0.7766434) < TOLERANCE) + expect_true(abs(eval_results[[3L]][["value"]] - 0.7527939) < TOLERANCE) }) test_that("learning-to-rank with lgb.cv() works as expected", {