diff --git a/NEWS.md b/NEWS.md index b9d5ab954..436703c8e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,6 +20,7 @@ * `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545). +* Functions now report the class of the outcome when the outcome is of the wrong class (#887). # parsnip 1.0.4 diff --git a/R/misc.R b/R/misc.R index 32225f34c..04937d1ca 100644 --- a/R/misc.R +++ b/R/misc.R @@ -336,14 +336,22 @@ check_outcome <- function(y, spec) { if (spec$mode == "regression") { outcome_is_numeric <- if (is.atomic(y)) {is.numeric(y)} else {all(map_lgl(y, is.numeric))} if (!outcome_is_numeric) { - rlang::abort("For a regression model, the outcome should be numeric.") + cls <- class(y)[[1]] + abort(paste0( + "For a regression model, the outcome should be `numeric`, ", + "not a `", cls, "`." + )) } } if (spec$mode == "classification") { outcome_is_factor <- if (is.atomic(y)) {is.factor(y)} else {all(map_lgl(y, is.factor))} if (!outcome_is_factor) { - rlang::abort("For a classification model, the outcome should be a factor.") + cls <- class(y)[[1]] + abort(paste0( + "For a classification model, the outcome should be a `factor`, ", + "not a `", cls, "`." + )) } if (inherits(spec, "logistic_reg") && is.atomic(y) && length(levels(y)) > 2) { @@ -361,7 +369,11 @@ check_outcome <- function(y, spec) { if (spec$mode == "censored regression") { outcome_is_surv <- inherits(y, "Surv") if (!outcome_is_surv) { - rlang::abort("For a censored regression model, the outcome should be a `Surv` object.") + cls <- class(y)[[1]] + abort(paste0( + "For a censored regression model, the outcome should be a `Surv` object, ", + "not a `", cls, "`." + )) } } diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index aa138541d..79f6fda71 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -132,3 +132,27 @@ Error in `fn()`: ! Please use `new_data` instead of `newdata`. 
+# check_outcome works as expected + + Code + check_outcome(factor(1:2), reg_spec) + Condition + Error in `check_outcome()`: + ! For a regression model, the outcome should be `numeric`, not a `factor`. + +--- + + Code + check_outcome(1:2, class_spec) + Condition + Error in `check_outcome()`: + ! For a classification model, the outcome should be a `factor`, not a `integer`. + +--- + + Code + check_outcome(1:2, cens_spec) + Condition + Error in `check_outcome()`: + ! For a censored regression model, the outcome should be a `Surv` object, not a `integer`. + diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R index e1059886e..f18631472 100644 --- a/tests/testthat/test_misc.R +++ b/tests/testthat/test_misc.R @@ -185,6 +185,51 @@ test_that('set_engine works as a generic', { test_that('check_for_newdata points out correct context', { fn <- function(...) {check_for_newdata(...); invisible()} expect_snapshot(error = TRUE, - fn(newdata = "boop!") + fn(newdata = "boop!") + ) +}) + +test_that('check_outcome works as expected', { + reg_spec <- linear_reg() + + expect_no_error( + check_outcome(1:2, reg_spec) + ) + + expect_no_error( + check_outcome(mtcars, reg_spec) + ) + + expect_snapshot( + error = TRUE, + check_outcome(factor(1:2), reg_spec) + ) + + class_spec <- logistic_reg() + + expect_no_error( + check_outcome(factor(1:2), class_spec) + ) + + expect_no_error( + check_outcome(lapply(mtcars, as.factor), class_spec) + ) + + expect_snapshot( + error = TRUE, + check_outcome(1:2, class_spec) + ) + + # Fake specification to avoid having to load {censored} + cens_spec <- logistic_reg() + cens_spec$mode <- "censored regression" + + expect_no_error( + check_outcome(survival::Surv(1, 1), cens_spec) + ) + + expect_snapshot( + error = TRUE, + check_outcome(1:2, cens_spec) ) }) diff --git a/tests/testthat/test_nearest_neighbor_kknn.R b/tests/testthat/test_nearest_neighbor_kknn.R index 002fb7d85..00fb5356e 100644 --- a/tests/testthat/test_nearest_neighbor_kknn.R +++ 
b/tests/testthat/test_nearest_neighbor_kknn.R @@ -24,7 +24,7 @@ test_that('kknn execution', { x = hpc[, num_pred], y = hpc$input_fields ), - regexp = "outcome should be a factor" + regexp = "outcome should be a `factor`" ) # nominal