From 05414fc72c212cd95986bc978f2aa75750bec85e Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 29 Aug 2024 13:11:28 -0700 Subject: [PATCH 1/5] make fit_xy() work with sparse tibbles --- R/sparsevctrs.R | 13 +++++++++++++ tests/testthat/_snaps/sparsevctrs.md | 8 ++++++++ tests/testthat/test-sparsevctrs.R | 23 +++++++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index a7a65cf45..599155389 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -7,6 +7,19 @@ to_sparse_data_frame <- function(x, object) { "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with engine {.code {object$engine}} doesn't accept that.") } + } else if (is.data.frame(x)) { + if ((!allow_sparse(object)) && + any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { + for (i in seq_along(ncol(x))) { + cli::cli_warn( + "{.arg x} is a sparse tibble, but {.fn {class(object)[1]}} with + engine {.code {object$engine}} doesn't accept that. Converting to + non-sparse." + ) + # materialize with [] + x[[i]] <- x[[i]][] + } + } } x } diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 02cb9611b..880fee53e 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -1,3 +1,11 @@ +# sparse tibble can be passed to `fit_xy() + + Code + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) + Condition + Warning: + `x` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. 
+ # sparse matrices can be passed to `fit_xy() Code diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 2f9027306..0243ac1d4 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -1,3 +1,26 @@ +test_that("sparse tibble can be passed to `fit_xy()", { + skip_if_not_installed("xgboost") + + hotel_data <- sparse_hotel_rates() + hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data) + + spec <- boost_tree() %>% + set_mode("regression") %>% + set_engine("xgboost") + + expect_no_error( + lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + + expect_snapshot( + lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]) + ) +}) + test_that("sparse matrices can be passed to `fit_xy()", { skip_if_not_installed("xgboost") From 6f06b786e3d049275c05e3a2c4a28324b648ae89 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 29 Aug 2024 15:07:48 -0700 Subject: [PATCH 2/5] make sparse tibbles work in fit() --- R/arguments.R | 1 + R/convert_data.R | 22 ++++++++++++++++++++-- R/fit.R | 13 +++++++++++++ R/fit_helpers.R | 5 +++++ R/sparsevctrs.R | 20 ++++++++++---------- tests/testthat/_snaps/sparsevctrs.md | 8 ++++++++ tests/testthat/test-sparsevctrs.R | 23 +++++++++++++++++++++++ 7 files changed, 80 insertions(+), 12 deletions(-) diff --git a/R/arguments.R b/R/arguments.R index 42721aac4..276ec3927 100644 --- a/R/arguments.R +++ b/R/arguments.R @@ -262,6 +262,7 @@ make_xy_call <- function(object, target, env) { none = rlang::expr(x), data.frame = rlang::expr(maybe_data_frame(x)), matrix = rlang::expr(maybe_matrix(x)), + dgCMatrix = rlang::expr(maybe_sparse_matrix(x)), rlang::abort(glue::glue("Invalid data type target: {target}.")) ) if (uses_weights) { diff --git a/R/convert_data.R b/R/convert_data.R index 64c93e02c..5fe1d528c 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -41,8 
+41,10 @@ indicators = "traditional", composition = "data.frame", remove_intercept = TRUE) { - if (!(composition %in% c("data.frame", "matrix"))) { - rlang::abort("`composition` should be either 'data.frame' or 'matrix'.") + if (!(composition %in% c("data.frame", "matrix", "dgCMatrix"))) { + rlang::abort( + "`composition` should be either 'data.frame', 'matrix', or 'dgCMatrix'." + ) } if (remove_intercept) { @@ -120,6 +122,18 @@ xlevels = .getXlevels(mod_terms, mod_frame), options = options ) + } else if (composition == "dgCMatrix") { + x <- sparsevctrs::coerce_to_sparse_matrix(data) + res <- + list( + x = x, + y = y, + weights = w, + offset = offset, + terms = mod_terms, + xlevels = .getXlevels(mod_terms, mod_frame), + options = options + ) } else { # Since a matrix is requested, try to convert y but check # to see if it is possible @@ -381,6 +395,10 @@ maybe_matrix <- function(x) { } maybe_sparse_matrix <- function(x) { + if (methods::is(x, "sparseMatrix")) { + return(x) + } + if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { res <- sparsevctrs::coerce_to_sparse_matrix(x) } else { diff --git a/R/fit.R b/R/fit.R index 0cfe7f5b0..3d814154e 100644 --- a/R/fit.R +++ b/R/fit.R @@ -174,6 +174,19 @@ fit.model_spec <- eval_env$formula <- formula eval_env$weights <- wts + if ((!allow_sparse(object)) && + any(vapply(data, sparsevctrs::is_sparse_vector, logical(1)))) { + cli::cli_warn( + "{.arg data} is a sparse tibble, but {.fn {class(object)[1]}} with + engine {.code {object$engine}} doesn't accept that. Converting to + non-sparse." 
+ ) + for (i in seq_along(ncol(data))) { + # materialize with [] + data[[i]] <- data[[i]][] + } + } + fit_interface <- check_interface(eval_env$formula, eval_env$data, cl, object) diff --git a/R/fit_helpers.R b/R/fit_helpers.R index fe7f75527..c4aeeb21a 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -125,6 +125,11 @@ form_xy <- function(object, control, env, indicators <- encoding_info %>% dplyr::pull(predictor_indicators) remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) + allow_sparse_x <- encoding_info %>% dplyr::pull(allow_sparse_x) + + if (allow_sparse_x && any(vapply(env$data, sparsevctrs::is_sparse_vector, logical(1)))) { + target <- "dgCMatrix" + } data_obj <- .convert_form_to_xy_fit( formula = env$formula, diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index 599155389..a0faa90e2 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -10,16 +10,16 @@ to_sparse_data_frame <- function(x, object) { } else if (is.data.frame(x)) { if ((!allow_sparse(object)) && any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { - for (i in seq_along(ncol(x))) { - cli::cli_warn( - "{.arg x} is a sparse tibble, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that. Converting to - non-sparse." - ) - # materialize with [] - x[[i]] <- x[[i]][] - } + cli::cli_warn( + "{.arg x} is a sparse tibble, but {.fn {class(object)[1]}} with + engine {.code {object$engine}} doesn't accept that. Converting to + non-sparse." 
+ ) + for (i in seq_along(ncol(x))) { + # materialize with [] + x[[i]] <- x[[i]][] + } } } x -} +} \ No newline at end of file diff --git a/tests/testthat/_snaps/sparsevctrs.md b/tests/testthat/_snaps/sparsevctrs.md index 880fee53e..3a849fbab 100644 --- a/tests/testthat/_snaps/sparsevctrs.md +++ b/tests/testthat/_snaps/sparsevctrs.md @@ -1,3 +1,11 @@ +# sparse tibble can be passed to `fit() + + Code + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + Condition + Warning: + `data` is a sparse tibble, but `linear_reg()` with engine `lm` doesn't accept that. Converting to non-sparse. + # sparse tibble can be passed to `fit_xy() Code diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 0243ac1d4..821a8a085 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -1,3 +1,26 @@ +test_that("sparse tibble can be passed to `fit()", { + skip_if_not_installed("xgboost") + + hotel_data <- sparse_hotel_rates() + hotel_data <- sparsevctrs::coerce_to_sparse_tibble(hotel_data) + + spec <- boost_tree() %>% + set_mode("regression") %>% + set_engine("xgboost") + + expect_no_error( + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data) + ) + + spec <- linear_reg() %>% + set_mode("regression") %>% + set_engine("lm") + + expect_snapshot( + lm_fit <- fit(spec, avg_price_per_room ~ ., data = hotel_data[1:100, ]) + ) +}) + test_that("sparse tibble can be passed to `fit_xy()", { skip_if_not_installed("xgboost") From cc4e2e529bfb07d59a7b10c93bd8e1aff1d81c4d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 29 Aug 2024 15:09:49 -0700 Subject: [PATCH 3/5] add news --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index fcde44d0b..98e9d1681 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * `fit_xy()` can now take dgCMatrix input for `x` argument (#1121). +* `fit()` and `fit_xy()` can now take sparse tibbles as data values (#1165). 
+ * Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161). From bf5c65715cd652bb8aeb9a6fb585e31d916f08aa Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 29 Aug 2024 15:57:22 -0700 Subject: [PATCH 4/5] refactor out is_sparse_tibble() --- R/convert_data.R | 2 +- R/fit.R | 3 +-- R/fit_helpers.R | 2 +- R/sparsevctrs.R | 7 +++++-- tests/testthat/test-sparsevctrs.R | 2 +- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/R/convert_data.R b/R/convert_data.R index 895f6c3f6..73fea702d 100644 --- a/R/convert_data.R +++ b/R/convert_data.R @@ -406,7 +406,7 @@ maybe_sparse_matrix <- function(x) { return(x) } - if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { + if (is_sparse_tibble(x)) { res <- sparsevctrs::coerce_to_sparse_matrix(x) } else { res <- as.matrix(x) diff --git a/R/fit.R b/R/fit.R index 3d814154e..ddb83df0e 100644 --- a/R/fit.R +++ b/R/fit.R @@ -174,8 +174,7 @@ fit.model_spec <- eval_env$formula <- formula eval_env$weights <- wts - if ((!allow_sparse(object)) && - any(vapply(data, sparsevctrs::is_sparse_vector, logical(1)))) { + if ((!allow_sparse(object)) && is_sparse_tibble(data)) { cli::cli_warn( "{.arg data} is a sparse tibble, but {.fn {class(object)[1]}} with engine {.code {object$engine}} doesn't accept that. 
Converting to diff --git a/R/fit_helpers.R b/R/fit_helpers.R index b037dde19..6143c4505 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -131,7 +131,7 @@ form_xy <- function(object, control, env, remove_intercept <- encoding_info %>% dplyr::pull(remove_intercept) allow_sparse_x <- encoding_info %>% dplyr::pull(allow_sparse_x) - if (allow_sparse_x && any(vapply(env$data, sparsevctrs::is_sparse_vector, logical(1)))) { + if (allow_sparse_x && is_sparse_tibble(env$data)) { target <- "dgCMatrix" } diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index a0faa90e2..08ba5c3a9 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -8,8 +8,7 @@ to_sparse_data_frame <- function(x, object) { engine {.code {object$engine}} doesn't accept that.") } } else if (is.data.frame(x)) { - if ((!allow_sparse(object)) && - any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { + if ((!allow_sparse(object)) && is_sparse_tibble(x)) { cli::cli_warn( "{.arg x} is a sparse tibble, but {.fn {class(object)[1]}} with engine {.code {object$engine}} doesn't accept that. 
Converting to @@ -22,4 +21,8 @@ to_sparse_data_frame <- function(x, object) { } } x +} + +is_sparse_tibble <- function(x) { + any(vapply(x, sparsevctrs::is_sparse_vector, logical(1))) } \ No newline at end of file diff --git a/tests/testthat/test-sparsevctrs.R b/tests/testthat/test-sparsevctrs.R index 821a8a085..4eb876ae8 100644 --- a/tests/testthat/test-sparsevctrs.R +++ b/tests/testthat/test-sparsevctrs.R @@ -112,7 +112,7 @@ test_that("maybe_sparse_matrix() is used correctly", { local_mocked_bindings( maybe_sparse_matrix = function(x) { - if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) { + if (is_sparse_tibble(x)) { stop("sparse vectors detected") } else { stop("no sparse vectors detected") From 60edee7cd516d4928a2f7a392de5a0e7ed0ffe3d Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 29 Aug 2024 16:31:58 -0700 Subject: [PATCH 5/5] refactor out materialize_sparse_tibble() --- R/fit.R | 14 ++------------ R/sparsevctrs.R | 29 +++++++++++++++++------------ 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/R/fit.R b/R/fit.R index ddb83df0e..52088a6b5 100644 --- a/R/fit.R +++ b/R/fit.R @@ -174,18 +174,8 @@ fit.model_spec <- eval_env$formula <- formula eval_env$weights <- wts - if ((!allow_sparse(object)) && is_sparse_tibble(data)) { - cli::cli_warn( - "{.arg data} is a sparse tibble, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that. Converting to - non-sparse." 
- ) - for (i in seq_along(ncol(data))) { - # materialize with [] - data[[i]] <- data[[i]][] - } - } - + data <- materialize_sparse_tibble(data, object, "data") + fit_interface <- check_interface(eval_env$formula, eval_env$data, cl, object) diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index 08ba5c3a9..d09416ba3 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -8,21 +8,26 @@ to_sparse_data_frame <- function(x, object) { engine {.code {object$engine}} doesn't accept that.") } } else if (is.data.frame(x)) { - if ((!allow_sparse(object)) && is_sparse_tibble(x)) { - cli::cli_warn( - "{.arg x} is a sparse tibble, but {.fn {class(object)[1]}} with - engine {.code {object$engine}} doesn't accept that. Converting to - non-sparse." - ) - for (i in seq_along(ncol(x))) { - # materialize with [] - x[[i]] <- x[[i]][] - } - } + x <- materialize_sparse_tibble(x, object, "x") } x } is_sparse_tibble <- function(x) { any(vapply(x, sparsevctrs::is_sparse_vector, logical(1))) -} \ No newline at end of file +} + +materialize_sparse_tibble <- function(x, object, input) { + if ((!allow_sparse(object)) && is_sparse_tibble(x)) { + cli::cli_warn( + "{.arg {input}} is a sparse tibble, but {.fn {class(object)[1]}} with + engine {.code {object$engine}} doesn't accept that. Converting to + non-sparse." + ) + for (i in seq_len(ncol(x))) { + # materialize every column with []; seq_along(ncol(x)) is always just 1 + x[[i]] <- x[[i]][] + } + } + x +}