diff --git a/NEWS.md b/NEWS.md index 4e81ab441..da2a95691 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,8 @@ * `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165). +* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167). + * Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081). diff --git a/R/fit.R b/R/fit.R index 52088a6b5..7be77c3ca 100644 --- a/R/fit.R +++ b/R/fit.R @@ -444,6 +444,10 @@ check_xy_interface <- function(x, y, cl, model) { } allow_sparse <- function(x) { + if (inherits(x, "model_fit")) { + x <- x$spec + } + res <- get_from_env(paste0(class(x)[1], "_encoding")) all(res$allow_sparse_x[res$engine == x$engine]) } diff --git a/R/predict.R b/R/predict.R index 285d17377..e92092aeb 100644 --- a/R/predict.R +++ b/R/predict.R @@ -160,6 +160,8 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) } check_pred_type_dots(object, type, ...) + new_data <- to_sparse_data_frame(new_data, object) + res <- switch( type, numeric = predict_numeric(object = object, new_data = new_data, ...), diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index 9b1452ad2..5fe3633ae 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -3,6 +3,10 @@ to_sparse_data_frame <- function(x, object) { if (allow_sparse(object)) { x <- sparsevctrs::coerce_to_sparse_data_frame(x) } else { + if (inherits(object, "model_fit")) { + object <- object$spec + } + cli::cli_abort( "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with engine {.code {object$engine}} doesn't accept that.") @@ -19,6 +23,10 @@ is_sparse_tibble <- function(x) { materialize_sparse_tibble <- function(x, object, input) { if (is_sparse_tibble(x) && (!allow_sparse(object))) { + if (inherits(object, "model_fit")) { + object <- object$spec + } + cli::cli_warn( "{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with engine {.code {object$engine}} doesn't accept that. Converting to diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 3a849fbab..7eb9d3a55 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -22,6 +22,22 @@ Error in `to_sparse_data_frame()`: ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. +# sparse tibble can be passed to `predict() + + Code + preds <- predict(lm_fit, sparse_mtcars) + Condition + Warning: + `x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + +# sparse matrices can be passed to `predict() + + Code + predict(lm_fit, sparse_mtcars) + Condition + Error in `to_sparse_data_frame()`: + ! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that. + # to_sparse_data_frame() is used correctly Code diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 4eb876ae8..aa452f2e3 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -67,6 +67,66 @@ test_that("sparse matrices can be passed to `fit_xy()", { ) }) +test_that("sparse tibble can be passed to `predict()", { + skip_if_not_installed("ranger") + + hotel_data <- sparse_hotel_rates() + hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data) + + spec <- rand_forest(trees = 10) %>% + set_mode("regression") %>% + set_engine("ranger") + + tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + + expect_no_error( + predict(tree_fit, hotel_data) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + + lm_fit <- fit(spec, mpg ~ ., data = mtcars) + + sparse_mtcars <- mtcars %>% + sparsevctrs::coerce_to_sparse_matrix() %>% + sparsevctrs::coerce_to_sparse_tibble() + + expect_snapshot( + preds <- predict(lm_fit, sparse_mtcars) + ) +}) + +test_that("sparse matrices can be passed to `predict()", { + skip_if_not_installed("ranger") + + hotel_data <- sparse_hotel_rates() + + spec <- rand_forest(trees = 10) %>% + set_mode("regression") %>% + set_engine("ranger") + + tree_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + + expect_no_error( + predict(tree_fit, hotel_data) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + + lm_fit <- fit(spec, mpg ~ ., data = mtcars) + + sparse_mtcars <- sparsevctrs::coerce_to_sparse_matrix(mtcars) + + expect_snapshot( + error = TRUE, + predict(lm_fit, sparse_mtcars) + ) +}) + test_that("to_sparse_data_frame() is used correctly", { skip_if_not_installed("xgboost")