diff --git a/NEWS.md b/NEWS.md index 37874d908..953e0b1cb 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * `fit_xy()` can now take dgCMatrix input for `x` argument (#1121). +* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165). + * Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081). diff --git a/R/arguments.R b/R/arguments.R index ad1033270..0e62a0938 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -264,6 +264,7 @@ make_xy_call <- function(object, target, env) { none = rlang::expr(x), data.frame = rlang::expr(maybe_data_frame(x)), matrix = rlang::expr(maybe_matrix(x)), + dgCMatrix = rlang::expr(maybe_sparse_matrix(x)), cli::cli_abort("Invalid data type target: {target}.") ) if (uses_weights) { diff --git a/R/convert_data.R b/R/convert_data.R index 87a20a3c0..73fea702d 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -41,9 +41,10 @@ indicators = "traditional", composition = "data.frame", remove_intercept = TRUE) { - if (!(composition %in% c("data.frame", "matrix"))) { + if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) { cli::cli_abort( - "{.arg composition} should be either {.val data.frame} or {.val matrix}." + "{.arg composition} should be either {.val data.frame}, {.val matrix}, or + {.val dgCMatrix}." ) } @@ -122,6 +123,18 @@ xlevels = .getXlevels(mod_terms, mod_frame), options = options ) + } else if (composition == "dgCMatrix") { + x <- sparsevctrs::coerce_to_sparse_matrix(data) + res <- + list( + x = x, + y = y, + weights = w, + offset = offset, + terms = mod_terms, + xlevels = .getXlevels(mod_terms, mod_frame), + options = options + ) } else { # Since a matrix is requested, try to convert y but check # to see if it is possible @@ -389,7 +402,11 @@ maybe_matrix <- function(x) { } maybe_sparse_matrix <- function(x) { - if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { + if (methods::is(x, "sparseMatrix")) { + return(x) + } + + if (is_sparse_tibble(x)) { res <- sparsevctrs::coerce_to_sparse_matrix(x) } else { res <- as.matrix(x) diff --git a/R/fit.R b/R/fit.R index 0cfe7f5b0..52088a6b5 100644 --- a/R/fit.R +++ b/R/fit.R @@ -174,6 +174,8 @@ fit.model_spec <- eval_env$formula <- formula eval_env$weights <- wts + data <- materialize_sparse_tibble(data, object, "data") + fit_interface <- check_interface(eval_env$formula, eval_env$data, cl, object) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index 625162771..6143c4505 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -129,6 +129,11 @@ form_xy <- function(object, control, env, indicators <- encoding_info %>% dplyr::pull(predictor_indicators) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) + allow_sparse_x <- encoding_info %>% dplyr::pull(allow_sparse_x) + + if (allow_sparse_x && is_sparse_tibble(env$data)) { + target <- "dgCMatrix" + } data_obj <- .convert_form_to_xy_fit( formula = env$formula, diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index a7a65cf45..d09416ba3 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -7,6 +7,27 @@ to_sparse_data_frame <- function(x, object) { "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with engine {.code {object$engine}} doesn't accept that.") } + } else if (is.data.frame(x)) { + x <- materialize_sparse_tibble(x, object, "x") + } + x +} + +is_sparse_tibble <- function(x) { + any(vapply(x, sparsevctrs::is_sparse_vector, logical(1))) +} + +materialize_sparse_tibble <- function(x, object, input) { + if ((!allow_sparse(object)) && is_sparse_tibble(x)) { + cli::cli_warn( + "{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with + engine {.code {object$engine}} doesn't accept that. Converting to + non-sparse." + ) + for (i in seq_along(ncol(x))) { + # materialize with [] + x[[i]] <- x[[i]][] + } } x } diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 02cb9611b..3a849fbab 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -1,3 +1,19 @@ +# sparse tibble can be passed to `fit() + + Code + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + Condition + Warning: + `data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + +# sparse tibble can be passed to `fit_xy() + + Code + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) + Condition + Warning: + `x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + # sparse matrices can be passed to `fit_xy() Code diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 2f9027306..4eb876ae8 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -1,3 +1,49 @@ +test_that("sparse tibble can be passed to `fit()", { + skip_if_not_installed("xgboost") + + hotel_data <- sparse_hotel_rates() + hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data) + + spec <- boost_tree() %>% + set_mode("regression") %>% + set_engine("xgboost") + + expect_no_error( + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + + expect_snapshot( + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + ) +}) + +test_that("sparse tibble can be passed to `fit_xy()", { + skip_if_not_installed("xgboost") + + hotel_data <- sparse_hotel_rates() + hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data) + + spec <- boost_tree() %>% + set_mode("regression") %>% + set_engine("xgboost") + + expect_no_error( + lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + + expect_snapshot( + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) + ) +}) + test_that("sparse matrices can be passed to `fit_xy()", { skip_if_not_installed("xgboost") @@ -66,7 +112,7 @@ test_that("maybe_sparse_matrix() is used correctly", { local_mocked_bindings( maybe_sparse_matrix = function(x) { - if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { + if (is_sparse_tibble(x)) { stop("sparse vectors detected") } else { stop("no sparse vectors detected")