Skip to content

Commit 1e7ec1e

Browse files
committed
feat: add forecast method #293
1 parent 6b488a4 commit 1e7ec1e

File tree

5 files changed

+69
-1
lines changed

5 files changed

+69
-1
lines changed

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ export(flatline)
152152
export(flatline_args_list)
153153
export(flatline_forecaster)
154154
export(flusight_hub_formatter)
155+
export(forecast)
155156
export(frosting)
156157
export(get_test_data)
157158
export(grab_names)

R/epi_workflow.R

+5-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,11 @@ update_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
197197
#'
198198
#' @export
199199
fit.epi_workflow <- function(object, data, ..., control = workflows::control_workflow()) {
200-
object$fit$meta <- list(max_time_value = max(data$time_value), as_of = attributes(data)$metadata$as_of)
200+
object$fit$meta <- list(
201+
max_time_value = max(data$time_value),
202+
as_of = attributes(data)$metadata$as_of,
203+
train_data = data
204+
)
201205

202206
NextMethod()
203207
}

R/forecast.R

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#' Produce forecast from an epi workflow
2+
#' @param epi_workflow An epi workflow
3+
#'
4+
#' @return A forecast tibble.
5+
#'
6+
#' @export
7+
forecast <- function(epi_workflow, fill_locf = FALSE, n_recent = NULL, forecast_date = NULL) {
8+
# Find data inside the epi_workflow
9+
if (!epi_workflow$trained) {
10+
cli_abort("The epi_workflow is not trained.")
11+
}
12+
13+
test_data <- get_test_data(
14+
hardhat::extract_preprocessor(epi_workflow),
15+
epi_workflow$fit$meta$train_data,
16+
fill_locf = fill_locf,
17+
n_recent = n_recent %||% Inf,
18+
forecast_date = forecast_date %||% max(epi_workflow$fit$meta$train_data$time_value)
19+
)
20+
21+
predict(epi_workflow, new_data = test_data)
22+
}

man/forecast.Rd

+22
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-forecast.R

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
test_that("forecast method works", {
2+
jhu <- case_death_rate_subset %>%
3+
filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
4+
r <- epi_recipe(jhu) %>%
5+
step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
6+
step_epi_ahead(death_rate, ahead = 7) %>%
7+
step_epi_naomit()
8+
wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
9+
10+
latest <- get_test_data(
11+
hardhat::extract_preprocessor(wf),
12+
jhu
13+
)
14+
15+
expect_equal(
16+
forecast(wf),
17+
predict(wf, new_data = latest)
18+
)
19+
})

0 commit comments

Comments
 (0)