diff --git a/DESCRIPTION b/DESCRIPTION index 593f1038..1994f2cc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.1.9 +Version: 0.1.10 Authors@R: c( person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), @@ -25,7 +25,7 @@ URL: https://github.com/cmu-delphi/epipredict/, BugReports: https://github.com/cmu-delphi/epipredict/issues/ Depends: epidatasets, - epiprocess (>= 0.9.0), + epiprocess (>= 0.10.4), parsnip (>= 1.0.0), R (>= 3.5.0) Imports: @@ -73,7 +73,6 @@ Remotes: cmu-delphi/epidatasets, cmu-delphi/epidatr, cmu-delphi/epiprocess, - cmu-delphi/epidatasets, dajmcdon/smoothqr Config/Needs/website: cmu-delphi/delphidocs Config/testthat/edition: 3 diff --git a/NEWS.md b/NEWS.md index 6b98d4d5..28f5bb99 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,9 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat `data()`, but can be accessed with `data(, package = "epidatasets")`, `epidatasets::` or, after loading the package, the name of the dataset alone (#382). +- Addresses upstream breaking changes from cmu-delphi/epiprocess#595 (`growth_rate()`). + `step_growth_rate()` has lost its `additional_gr_args_list` argument and now + has an `na_rm` argument. ## Improvements diff --git a/R/arx_classifier.R b/R/arx_classifier.R index 240bc69e..1c4f35ef 100644 --- a/R/arx_classifier.R +++ b/R/arx_classifier.R @@ -27,8 +27,9 @@ #' #' @examples #' library(dplyr) +#' tiny_geos <- c("as", "mp", "vi", "gu", "pr") #' jhu <- covid_case_death_rates %>% -#' filter(time_value >= as.Date("2021-11-01")) +#' filter(time_value >= as.Date("2021-11-01"), !(geo_value %in% tiny_geos)) #' #' out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate")) #' @@ -58,7 +59,10 @@ arx_classifier <- function( if (args_list$adjust_latency == "none") { forecast_date_default <- max(epi_data$time_value) if (!is.null(args_list$forecast_date) && args_list$forecast_date != forecast_date_default) { - cli_warn("The specified forecast date {args_list$forecast_date} doesn't match the date from which the forecast is occurring {forecast_date}.") + cli_warn( + "The specified forecast date {args_list$forecast_date} doesn't match the + date from which the forecast is occurring {forecast_date}." + ) } } else { forecast_date_default <- attributes(epi_data)$metadata$as_of @@ -101,7 +105,7 @@ arx_classifier <- function( #' #' @return An unfit `epi_workflow`. #' @export -#' @seealso [arx_classifier()] +#' @seealso [arx_classifier()] [arx_class_args_list()] #' @examples #' library(dplyr) #' jhu <- covid_case_death_rates %>% @@ -154,12 +158,13 @@ arx_class_epi_workflow <- function( role = "grp", horizon = args_list$horizon, method = args_list$method, - log_scale = args_list$log_scale, - additional_gr_args_list = args_list$additional_gr_args + log_scale = args_list$log_scale ) for (l in seq_along(lags)) { pred_names <- predictors[l] - pred_names <- as.character(glue::glue_data(args_list, "gr_{horizon}_{method}_{pred_names}")) + pred_names <- as.character(glue::glue_data( + args_list, "gr_{horizon}_{method}_{pred_names}" + )) r <- step_epi_lag(r, !!pred_names, lag = lags[[l]]) } # ------- outcome @@ -185,8 +190,7 @@ arx_class_epi_workflow <- function( role = "pre-outcome", horizon = args_list$horizon, method = args_list$method, - log_scale = args_list$log_scale, - additional_gr_args_list = args_list$additional_gr_args + log_scale = args_list$log_scale ) } } @@ -270,9 +274,6 @@ arx_class_epi_workflow <- function( #' @param method Character. Options available for growth rate calculation. #' @param log_scale Scalar logical. Whether to compute growth rates on the #' log scale. -#' @param additional_gr_args List. Optional arguments controlling growth rate -#' calculation. See [epiprocess::growth_rate()] and the related Vignette for -#' more details. #' @param check_enough_data_n Integer. A lower limit for the number of rows per #' epi_key that are required for training. If `NULL`, this check is ignored. #' @param check_enough_data_epi_keys Character vector. A character vector of @@ -301,7 +302,6 @@ arx_class_args_list <- function( horizon = 7L, method = c("rel_change", "linear_reg"), log_scale = FALSE, - additional_gr_args = list(), check_enough_data_n = NULL, check_enough_data_epi_keys = NULL, ...) { @@ -320,23 +320,14 @@ arx_class_args_list <- function( arg_is_lgl(log_scale) arg_is_pos(n_training) if (is.finite(n_training)) arg_is_pos_int(n_training) - if (!is.list(additional_gr_args)) { - cli_abort(c( - "`additional_gr_args` must be a {.cls list}.", - "!" = "This is a {.cls {class(additional_gr_args)}}.", - i = "See `?epiprocess::growth_rate` for available arguments." - )) - } arg_is_pos(check_enough_data_n, allow_null = TRUE) arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE) if (!is.null(forecast_date) && !is.null(target_date)) { if (forecast_date + ahead != target_date) { cli_warn( - paste0( - "`forecast_date` {.val {forecast_date}} +", - " `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}." - ), + "`forecast_date` {.val {forecast_date}} + + `ahead` {.val {ahead}} must equal `target_date` {.val {target_date}}.", class = "epipredict__arx_args__inconsistent_target_ahead_forecaste_date" ) } @@ -362,7 +353,6 @@ arx_class_args_list <- function( horizon, method, log_scale, - additional_gr_args, check_enough_data_n, check_enough_data_epi_keys ), diff --git a/R/step_growth_rate.R b/R/step_growth_rate.R index b3a71231..159c2439 100644 --- a/R/step_growth_rate.R +++ b/R/step_growth_rate.R @@ -22,9 +22,7 @@ #' being removed from the data. Alternatively, you could specify arbitrary #' large values, or perhaps zero. Setting this argument to `NULL` will result #' in no replacement. -#' @param additional_gr_args_list A list of additional arguments used by -#' [epiprocess::growth_rate()]. All `...` arguments may be passed here along -#' with `dup_rm` and `na_rm`. +#' @inheritParams epiprocess::growth_rate #' @template step-return #' #' @@ -32,12 +30,17 @@ #' @importFrom epiprocess growth_rate #' @export #' @examples -#' r <- epi_recipe(covid_case_death_rates) %>% +#' library(dplyr) +#' tiny_geos <- c("as", "mp", "vi", "gu", "pr") +#' rates <- covid_case_death_rates %>% +#' filter(time_value >= as.Date("2021-11-01"), !(geo_value %in% tiny_geos)) +#' +#' r <- epi_recipe(rates) %>% #' step_growth_rate(case_rate, death_rate) #' r #' #' r %>% -#' prep(covid_case_death_rates) %>% +#' prep(rates) %>% #' bake(new_data = NULL) step_growth_rate <- function(recipe, @@ -46,11 +49,11 @@ step_growth_rate <- horizon = 7, method = c("rel_change", "linear_reg"), log_scale = FALSE, + na_rm = TRUE, replace_Inf = NA, prefix = "gr_", skip = FALSE, - id = rand_id("growth_rate"), - additional_gr_args_list = list()) { + id = rand_id("growth_rate")) { if (!is_epi_recipe(recipe)) { cli_abort("This recipe step can only operate on an {.cls epi_recipe}.") } @@ -63,15 +66,7 @@ step_growth_rate <- } arg_is_chr(role) arg_is_chr_scalar(prefix, id) - arg_is_lgl_scalar(log_scale, skip) - - - if (!is.list(additional_gr_args_list)) { - cli_abort(c( - "`additional_gr_args_list` must be a {.cls list}.", - i = "See `?epiprocess::growth_rate` for available options." - )) - } + arg_is_lgl_scalar(log_scale, skip, na_rm) recipes::add_step( recipe, @@ -82,13 +77,13 @@ step_growth_rate <- horizon = horizon, method = method, log_scale = log_scale, + na_rm = na_rm, replace_Inf = replace_Inf, prefix = prefix, keys = key_colnames(recipe), columns = NULL, skip = skip, - id = id, - additional_gr_args_list = additional_gr_args_list + id = id ) ) } @@ -101,13 +96,13 @@ step_growth_rate_new <- horizon, method, log_scale, + na_rm, replace_Inf, prefix, keys, columns, skip, - id, - additional_gr_args_list) { + id) { recipes::step( subclass = "growth_rate", terms = terms, @@ -116,13 +111,13 @@ step_growth_rate_new <- horizon = horizon, method = method, log_scale = log_scale, + na_rm = na_rm, replace_Inf = replace_Inf, prefix = prefix, keys = keys, columns = columns, skip = skip, - id = id, - additional_gr_args_list = additional_gr_args_list + id = id ) } @@ -137,13 +132,13 @@ prep.step_growth_rate <- function(x, training, info = NULL, ...) { horizon = x$horizon, method = x$method, log_scale = x$log_scale, + na_rm = x$na_rm, replace_Inf = x$replace_Inf, prefix = x$prefix, keys = x$keys, columns = recipes::recipes_eval_select(x$terms, training, info), skip = x$skip, - id = x$id, - additional_gr_args_list = x$additional_gr_args_list + id = x$id ) } @@ -177,10 +172,12 @@ bake.step_growth_rate <- function(object, new_data, ...) { across( all_of(object$columns), ~ epiprocess::growth_rate( - time_value, .x, + .x, + x = time_value, method = object$method, - h = object$horizon, log_scale = object$log_scale, - !!!object$additional_gr_args_list + h = object$horizon, + log_scale = object$log_scale, + na_rm = object$na_rm ), .names = "{object$prefix}{object$horizon}_{object$method}_{.col}" ) diff --git a/man/arx_class_args_list.Rd b/man/arx_class_args_list.Rd index dbf27535..40bb48ca 100644 --- a/man/arx_class_args_list.Rd +++ b/man/arx_class_args_list.Rd @@ -17,7 +17,6 @@ arx_class_args_list( horizon = 7L, method = c("rel_change", "linear_reg"), log_scale = FALSE, - additional_gr_args = list(), check_enough_data_n = NULL, check_enough_data_epi_keys = NULL, ... @@ -96,10 +95,6 @@ calculate the growth rate.} \item{log_scale}{Scalar logical. Whether to compute growth rates on the log scale.} -\item{additional_gr_args}{List. Optional arguments controlling growth rate -calculation. See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}} and the related Vignette for -more details.} - \item{check_enough_data_n}{Integer. A lower limit for the number of rows per epi_key that are required for training. If \code{NULL}, this check is ignored.} diff --git a/man/arx_class_epi_workflow.Rd b/man/arx_class_epi_workflow.Rd index 9f0aae6a..7497fe95 100644 --- a/man/arx_class_epi_workflow.Rd +++ b/man/arx_class_epi_workflow.Rd @@ -65,5 +65,5 @@ arx_class_epi_workflow( ) } \seealso{ -\code{\link[=arx_classifier]{arx_classifier()}} +\code{\link[=arx_classifier]{arx_classifier()}} \code{\link[=arx_class_args_list]{arx_class_args_list()}} } diff --git a/man/arx_classifier.Rd b/man/arx_classifier.Rd index 94503f3d..22edc98e 100644 --- a/man/arx_classifier.Rd +++ b/man/arx_classifier.Rd @@ -49,8 +49,9 @@ that it estimates a class at a particular target horizon. } \examples{ library(dplyr) +tiny_geos <- c("as", "mp", "vi", "gu", "pr") jhu <- covid_case_death_rates \%>\% - filter(time_value >= as.Date("2021-11-01")) + filter(time_value >= as.Date("2021-11-01"), !(geo_value \%in\% tiny_geos)) out <- arx_classifier(jhu, "death_rate", c("case_rate", "death_rate")) diff --git a/man/epi_recipe.Rd b/man/epi_recipe.Rd index 98775eae..b93e4bcf 100644 --- a/man/epi_recipe.Rd +++ b/man/epi_recipe.Rd @@ -56,17 +56,17 @@ anything but common roles are \code{"outcome"}, \code{"predictor"}, } \value{ An object of class \code{recipe} with sub-objects: -\item{var_info}{A tibble containing information about the original data -set columns} +\item{var_info}{A tibble containing information about the original data set +columns.} \item{term_info}{A tibble that contains the current set of terms in the data set. This initially defaults to the same data contained in \code{var_info}.} -\item{steps}{A list of \code{step} or \code{check} objects that define the sequence of -preprocessing operations that will be applied to data. The default value is -\code{NULL}} -\item{template}{A tibble of the data. This is initialized to be the same -as the data given in the \code{data} argument but can be different after -the recipe is trained.} +\item{steps}{A list of \code{step} or \code{check} objects that define the sequence +of preprocessing operations that will be applied to data. The default value +is \code{NULL}.} +\item{template}{A tibble of the data. This is initialized to be the same as +the data given in the \code{data} argument but can be different after the recipe +is trained.} } \description{ A recipe is a description of the steps to be applied to a data set in diff --git a/man/step_growth_rate.Rd b/man/step_growth_rate.Rd index 12963f8d..2c98e74e 100644 --- a/man/step_growth_rate.Rd +++ b/man/step_growth_rate.Rd @@ -11,11 +11,11 @@ step_growth_rate( horizon = 7, method = c("rel_change", "linear_reg"), log_scale = FALSE, + na_rm = TRUE, replace_Inf = NA, prefix = "gr_", skip = FALSE, - id = rand_id("growth_rate"), - additional_gr_args_list = list() + id = rand_id("growth_rate") ) } \arguments{ @@ -41,6 +41,9 @@ growth rates). See \code{\link[epiprocess:growth_rate]{epiprocess::growth_rate() \item{log_scale}{Should growth rates be estimated using the parameterization on the log scale? See details for an explanation. Default is \code{FALSE}.} +\item{na_rm}{Should missing values be removed before the computation? Default +is \code{FALSE}.} + \item{replace_Inf}{Sometimes, the growth rate calculation can result in infinite values (if the denominator is zero, for example). In this case, most prediction methods will fail. This argument specifies potential @@ -59,10 +62,6 @@ Care should be taken when using \code{skip = TRUE} as it may affect the computations for subsequent operations.} \item{id}{A unique identifier for the step} - -\item{additional_gr_args_list}{A list of additional arguments used by -\code{\link[epiprocess:growth_rate]{epiprocess::growth_rate()}}. All \code{...} arguments may be passed here along -with \code{dup_rm} and \code{na_rm}.} } \value{ An updated version of \code{recipe} with the new step added to the @@ -73,12 +72,17 @@ sequence of any existing operations. that will generate one or more new columns of derived data. } \examples{ -r <- epi_recipe(covid_case_death_rates) \%>\% +library(dplyr) +tiny_geos <- c("as", "mp", "vi", "gu", "pr") +rates <- covid_case_death_rates \%>\% + filter(time_value >= as.Date("2021-11-01"), !(geo_value \%in\% tiny_geos)) + +r <- epi_recipe(rates) \%>\% step_growth_rate(case_rate, death_rate) r r \%>\% - prep(covid_case_death_rates) \%>\% + prep(rates) \%>\% bake(new_data = NULL) } \seealso{ diff --git a/tests/testthat/_snaps/layers.md b/tests/testthat/_snaps/layers.md index a0474eab..7f208f2e 100644 --- a/tests/testthat/_snaps/layers.md +++ b/tests/testthat/_snaps/layers.md @@ -3,7 +3,7 @@ Code update(f$layers[[1]], lower = 100) Condition - Error in `recipes:::update_fields()`: + Error in `update()`: ! The step you are trying to update, `layer_predict()`, does not have the lower field. --- @@ -19,6 +19,6 @@ Code update(f$layers[[2]], bad_param = 100) Condition - Error in `recipes:::update_fields()`: + Error in `update()`: ! The step you are trying to update, `layer_threshold()`, does not have the bad_param field. diff --git a/tests/testthat/_snaps/snapshots.md b/tests/testthat/_snaps/snapshots.md index 17191e04..d1cf2df7 100644 --- a/tests/testthat/_snaps/snapshots.md +++ b/tests/testthat/_snaps/snapshots.md @@ -1237,54 +1237,52 @@ # arx_classifier snapshots structure(list(geo_value = c("ak", "al", "ar", "az", "ca", "co", - "ct", "dc", "de", "fl", "ga", "gu", "hi", "ia", "id", "il", "in", - "ks", "ky", "la", "ma", "me", "mi", "mn", "mo", "mp", "ms", "mt", - "nc", "nd", "ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or", - "pa", "pr", "ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa", - "wi", "wv", "wy"), .pred_class = structure(c(1L, 1L, 1L, 1L, + "ct", "dc", "de", "fl", "ga", "hi", "ia", "id", "il", "in", "ks", + "ky", "la", "ma", "me", "mi", "mn", "mo", "ms", "mt", "nc", "nd", + "ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or", "pa", "pr", + "ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa", "wi", "wv", + "wy"), .pred_class = structure(c(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, - 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, - 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, - 1L), levels = c("(-Inf,0.25]", "(0.25, Inf]"), class = "factor"), - forecast_date = structure(c(18992, 18992, 18992, 18992, 18992, - 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, - 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, - 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, - 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, - 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, - 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999, - 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, - 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, - 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, - 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, - 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, - 18999, 18999, 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, - -53L), class = c("tbl_df", "tbl", "data.frame")) + 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, + 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L), levels = c("(-Inf,0.25]", + "(0.25, Inf]"), class = "factor"), forecast_date = structure(c(18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, 18992, + 18992, 18992, 18992, 18992, 18992), class = "Date"), target_date = structure(c(18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, 18999, + 18999, 18999, 18999, 18999, 18999), class = "Date")), row.names = c(NA, + -51L), class = c("tbl_df", "tbl", "data.frame")) --- structure(list(geo_value = c("ak", "al", "ar", "az", "ca", "co", - "ct", "dc", "de", "fl", "ga", "gu", "hi", "ia", "id", "il", "in", - "ks", "ky", "la", "ma", "me", "mi", "mn", "mo", "mp", "ms", "mt", - "nc", "nd", "ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or", - "pa", "pr", "ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa", - "wi", "wv", "wy"), .pred_class = structure(c(1L, 1L, 1L, 1L, - 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, + "ct", "dc", "de", "fl", "ga", "hi", "ia", "id", "il", "in", "ks", + "ky", "la", "ma", "me", "mi", "mn", "mo", "ms", "mt", "nc", "nd", + "ne", "nh", "nj", "nm", "nv", "ny", "oh", "ok", "or", "pa", "pr", + "ri", "sc", "sd", "tn", "tx", "ut", "va", "vt", "wa", "wi", "wv", + "wy"), .pred_class = structure(c(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, - 1L), levels = c("(-Inf,0.25]", "(0.25, Inf]"), class = "factor"), - forecast_date = structure(c(18994, 18994, 18994, 18994, 18994, - 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, - 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, - 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, - 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, - 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, - 18994, 18994, 18994), class = "Date"), target_date = structure(c(19001, - 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, - 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, - 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, - 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, - 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, - 19001, 19001, 19001, 19001, 19001, 19001, 19001), class = "Date")), row.names = c(NA, - -53L), class = c("tbl_df", "tbl", "data.frame")) + 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L), levels = c("(-Inf,0.25]", + "(0.25, Inf]"), class = "factor"), forecast_date = structure(c(18994, + 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, + 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, + 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, + 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, + 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, 18994, + 18994, 18994, 18994, 18994, 18994), class = "Date"), target_date = structure(c(19001, + 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, + 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, + 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, + 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, + 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, 19001, + 19001, 19001, 19001, 19001, 19001), class = "Date")), row.names = c(NA, + -51L), class = c("tbl_df", "tbl", "data.frame")) diff --git a/tests/testthat/_snaps/step_growth_rate.md b/tests/testthat/_snaps/step_growth_rate.md index 0912977e..409aa197 100644 --- a/tests/testthat/_snaps/step_growth_rate.md +++ b/tests/testthat/_snaps/step_growth_rate.md @@ -89,19 +89,18 @@ --- Code - step_growth_rate(r, value, skip = 1) + step_growth_rate(r, value, na_rm = 1) Condition Error in `step_growth_rate()`: - ! `skip` must be a scalar of type . + ! `na_rm` must be a scalar of type . --- Code - step_growth_rate(r, value, additional_gr_args_list = 1:5) + step_growth_rate(r, value, skip = 1) Condition Error in `step_growth_rate()`: - ! `additional_gr_args_list` must be a . - i See `?epiprocess::growth_rate` for available options. + ! `skip` must be a scalar of type . --- diff --git a/tests/testthat/test-snapshots.R b/tests/testthat/test-snapshots.R index 3aecfd6b..8b2ae838 100644 --- a/tests/testthat/test-snapshots.R +++ b/tests/testthat/test-snapshots.R @@ -146,16 +146,18 @@ test_that("arx_forecaster output format snapshots", { }) test_that("arx_classifier snapshots", { - arc1 <- arx_classifier( - covid_case_death_rates %>% + train <- covid_case_death_rates %>% + filter(geo_value %nin% c("as", "gu", "mp", "vi")) + expect_warning(arc1 <- arx_classifier( + train %>% dplyr::filter(time_value >= as.Date("2021-11-01")), "death_rate", c("case_rate", "death_rate") - ) + ), "fitted probabilities numerically") expect_snapshot_tibble(arc1$predictions) - max_date <- covid_case_death_rates$time_value %>% max() + max_date <- train$time_value %>% max() arc2 <- arx_classifier( - covid_case_death_rates %>% + train %>% dplyr::filter(time_value >= as.Date("2021-11-01")), "death_rate", c("case_rate", "death_rate"), @@ -164,7 +166,7 @@ test_that("arx_classifier snapshots", { expect_snapshot_tibble(arc2$predictions) expect_error( arc3 <- arx_classifier( - covid_case_death_rates %>% + train %>% dplyr::filter(time_value >= as.Date("2021-11-01")), "death_rate", c("case_rate", "death_rate"), @@ -174,7 +176,7 @@ test_that("arx_classifier snapshots", { ) expect_error( arc4 <- arx_classifier( - covid_case_death_rates %>% + train %>% dplyr::filter(time_value >= as.Date("2021-11-01")), "death_rate", c("case_rate", "death_rate"), diff --git a/tests/testthat/test-step_growth_rate.R b/tests/testthat/test-step_growth_rate.R index f2845d81..dc8f04c9 100644 --- a/tests/testthat/test-step_growth_rate.R +++ b/tests/testthat/test-step_growth_rate.R @@ -16,8 +16,8 @@ test_that("step_growth_rate validates arguments", { expect_snapshot(error = TRUE, step_growth_rate(r, value, prefix = 1)) expect_snapshot(error = TRUE, step_growth_rate(r, value, id = 1)) expect_snapshot(error = TRUE, step_growth_rate(r, value, log_scale = 1)) + expect_snapshot(error = TRUE, step_growth_rate(r, value, na_rm = 1)) expect_snapshot(error = TRUE, step_growth_rate(r, value, skip = 1)) - expect_snapshot(error = TRUE, step_growth_rate(r, value, additional_gr_args_list = 1:5)) expect_snapshot(error = TRUE, step_growth_rate(r, value, replace_Inf = "c")) expect_snapshot(error = TRUE, step_growth_rate(r, value, replace_Inf = c(1, 2))) expect_silent(step_growth_rate(r, value, replace_Inf = NULL))