Skip to content

fit() no longer drops the epi_workflow class #377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.22
Version: 0.0.23
Authors@R: c(
person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"),
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ importFrom(rlang,":=")
importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,as_function)
importFrom(rlang,caller_arg)
importFrom(rlang,caller_env)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Add `step_epi_slide` to produce generic sliding computations over an `epi_df`
- Add quantile random forests (via `{grf}`) as a parsnip engine
- Replace `epi_keys()` with `epiprocess::key_colnames()`, #352
- Fix bug where `fit()` drops the `epi_workflow` class, #363
- Try to retain the `epi_df` class during baking to the extent possible, #376
24 changes: 9 additions & 15 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,10 @@ epi_recipe <- function(x, ...) {
#' @rdname epi_recipe
#' @export
epi_recipe.default <- function(x, ...) {
## if not a formula or an epi_df, we just pass to recipes::recipe
if (is.matrix(x) || is.data.frame(x) || tibble::is_tibble(x)) {
x <- x[1, , drop = FALSE]
}
cli_warn(
"epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various
step_epi_* functions will not work."
)
recipes::recipe(x, ...)
cli_abort(paste(
"`x` must be an {.cls epi_df} or a {.cls formula},",
"not a {.cls {class(x)[[1]]}}."
))
}

#' @rdname epi_recipe
Expand Down Expand Up @@ -154,17 +149,16 @@ epi_recipe.formula <- function(formula, data, ...) {
data <- data[1, ]
# check for minus:
if (!epiprocess::is_epi_df(data)) {
cli_warn(
"epi_recipe has been called with a non-epi_df object, returning a regular recipe. Various
step_epi_* functions will not work."
)
return(recipes::recipe(formula, data, ...))
cli_abort(paste(
"`epi_recipe()` has been called with a non-{.cls epi_df} object.",
"Use `recipe()` instead."
))
}

attr(data, "decay_to_tibble") <- FALSE
f_funcs <- recipes:::fun_calls(formula, data)
if (any(f_funcs == "-")) {
abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
cli_abort("`-` is not allowed in a recipe formula. Use `step_rm()` instead.")
}

# Check for other in-line functions
Expand Down
4 changes: 3 additions & 1 deletion R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ fit.epi_workflow <- function(object, data, ..., control = workflows::control_wor
)
object$original_data <- data

NextMethod()
res <- NextMethod()
class(res) <- c("epi_workflow", class(res))
res
}

#' Predict from an epi_workflow
Expand Down
13 changes: 7 additions & 6 deletions R/epipredict-package.R
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
## usethis namespace: start
#' @importFrom tibble as_tibble
#' @importFrom rlang := !! %||% as_function global_env set_names !!!
#' is_logical is_true inject enquo enquos expr sym arg_match
#' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg
#' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match
#' @importFrom stats poly predict lm residuals quantile
#' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by
#' summarize filter mutate select left_join rename ungroup full_join
#' relocate summarise everything
#' @importFrom dplyr summarize filter mutate select left_join rename ungroup
#' @importFrom dplyr full_join relocate summarise everything
#' @importFrom cli cli_abort cli_warn
#' @importFrom checkmate assert assert_character assert_int assert_scalar
#' assert_logical assert_numeric assert_number assert_integer
#' assert_integerish assert_date assert_function assert_class
#' @importFrom checkmate assert_logical assert_numeric assert_number
#' @importFrom checkmate assert_integer assert_integerish
#' @importFrom checkmate assert_date assert_function assert_class
#' @import epiprocess parsnip
## usethis namespace: end
NULL
24 changes: 24 additions & 0 deletions tests/testthat/_snaps/epi_recipe.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# epi_recipe produces error if not an epi_df

Code
epi_recipe(tib)
Condition
Error in `epi_recipe()`:
! `x` must be an <epi_df> or a <formula>, not a <tbl_df>.

---

Code
epi_recipe(y ~ x, tib)
Condition
Error in `epi_recipe()`:
! `epi_recipe()` has been called with a non-<epi_df> object. Use `recipe()` instead.

---

Code
epi_recipe(m)
Condition
Error in `epi_recipe()`:
! `x` must be an <epi_df> or a <formula>, not a <matrix>.

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/epi_workflow.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# fit method does not silently drop the class

Code
epi_recipe(y ~ x, data = tbl)
Condition
Error in `epi_recipe()`:
! `epi_recipe()` has been called with a non-<epi_df> object. Use `recipe()` instead.

---

Code
ewf_erec_edf %>% fit(tbl)
Condition
Error in `if (new_meta != old_meta) ...`:
! argument is of length zero

23 changes: 4 additions & 19 deletions tests/testthat/test-epi_recipe.R
Original file line number Diff line number Diff line change
@@ -1,27 +1,12 @@
test_that("epi_recipe produces default recipe", {
# these all call recipes::recipe(), but the template will always have 1 row
test_that("epi_recipe produces error if not an epi_df", {
tib <- tibble(
x = 1:5, y = 1:5,
time_value = seq(as.Date("2020-01-01"), by = 1, length.out = 5)
)
expected_rec <- recipes::recipe(tib)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(tib), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)

expected_rec <- recipes::recipe(y ~ x, tib)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(y ~ x, tib), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)

expect_snapshot(error = TRUE, epi_recipe(tib))
expect_snapshot(error = TRUE, epi_recipe(y ~ x, tib))
m <- as.matrix(tib)
expected_rec <- recipes::recipe(m)
expected_rec$template <- expected_rec$template[1, ]
expect_warning(rec <- epi_recipe(m), regexp = "epi_recipe has been called with a non-epi_df object")
expect_identical(expected_rec, rec)
expect_equal(nrow(rec$template), 1L)
expect_snapshot(error = TRUE, epi_recipe(m))
})

test_that("epi_recipe formula works", {
Expand Down
37 changes: 37 additions & 0 deletions tests/testthat/test-epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,40 @@ test_that("forecast method errors when workflow not fit", {

expect_error(forecast(wf))
})

test_that("fit method does not silently drop the class", {
# This is issue #363

library(recipes)
tbl <- tibble::tibble(
geo_value = 1,
time_value = 1:100,
x = 1:100,
y = x + rnorm(100L)
)
edf <- as_epi_df(tbl)

rec_tbl <- recipe(y ~ x, data = tbl)
rec_edf <- recipe(y ~ x, data = edf)
expect_snapshot(error = TRUE, epi_recipe(y ~ x, data = tbl))
erec_edf <- epi_recipe(y ~ x, data = edf)

ewf_rec_tbl <- epi_workflow(rec_tbl, linear_reg())
ewf_rec_edf <- epi_workflow(rec_edf, linear_reg())
ewf_erec_edf <- epi_workflow(erec_edf, linear_reg())

# above are all epi_workflows:

expect_s3_class(ewf_rec_tbl, "epi_workflow")
expect_s3_class(ewf_rec_edf, "epi_workflow")
expect_s3_class(ewf_erec_edf, "epi_workflow")

# but fitting drops the class or generates errors in many cases:

expect_s3_class(ewf_rec_tbl %>% fit(tbl), "epi_workflow")
expect_s3_class(ewf_rec_tbl %>% fit(edf), "epi_workflow")
expect_s3_class(ewf_rec_edf %>% fit(tbl), "epi_workflow")
expect_s3_class(ewf_rec_edf %>% fit(edf), "epi_workflow")
expect_snapshot(ewf_erec_edf %>% fit(tbl), error = TRUE)
expect_s3_class(ewf_erec_edf %>% fit(edf), "epi_workflow")
})
Loading