Skip to content

Cdc baseline #245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
8a1e2d6
start CDC baseline layer
dajmcdon Sep 14, 2023
cea1599
upgrade enframer
dajmcdon Sep 22, 2023
d606741
functions, remains to check validity
dajmcdon Sep 23, 2023
7294c00
correct symmetrization, enhance documentation of the "ahead" param in…
dajmcdon Sep 24, 2023
f18e88f
better defaults, cli, pred is scalar in propagate_samples
dajmcdon Sep 24, 2023
d6a28f3
redocument
dajmcdon Sep 24, 2023
237ec50
run styler
dajmcdon Sep 24, 2023
c13b83e
redocument after styling
dajmcdon Sep 24, 2023
16f6c2c
example plotting with ggplot2 handled correctly
dajmcdon Sep 25, 2023
c9b4667
working cdc baseline
dajmcdon Oct 4, 2023
d59a691
add cdc baseline to pkgdown
dajmcdon Oct 4, 2023
21b4c85
local checks pass
dajmcdon Oct 4, 2023
1f58e67
Fix incomplete `symmetrize` + document it
brookslogan Oct 4, 2023
fe31a79
`document()`
brookslogan Oct 4, 2023
93fac1a
Fix death rate 7dav -> weekly sum conversion in example
brookslogan Oct 4, 2023
0516836
Copyediting, roxygen link styling, formatting
brookslogan Oct 4, 2023
0905ba4
"Logical" -> "Scalar logical" as appropriate to match rest of docs
brookslogan Oct 4, 2023
6f14e6a
add formatter to the correct branch
dajmcdon Oct 5, 2023
a33baae
Merge branch 'cdc-baseline' of https://github.com/cmu-delphi/epipredi…
dajmcdon Oct 5, 2023
b1c34cb
Speed up cdc baseline: `quantile(...., names = FALSE)`
brookslogan Oct 5, 2023
cf9a44e
CI: trying to change in a particular branch too
dsweber2 Oct 5, 2023
343add1
formatter works
dajmcdon Oct 5, 2023
b5fe624
local checks pass
dajmcdon Oct 5, 2023
bdbd3ee
Merge branch 'cdc-baseline' of https://github.com/cmu-delphi/epipredi…
dajmcdon Oct 5, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
^musings$
^data-raw$
^vignettes/articles$
^.git-blame-ignore-revs$
4 changes: 2 additions & 2 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# Created with usethis + edited to use API key.
on:
push:
branches: [main, master]
branches: [main, master, v0.0.6]
pull_request:
branches: [main, master]
branches: [main, master, v0.0.6]

name: R-CMD-check

Expand Down
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ S3method(extrapolate_quantiles,dist_default)
S3method(extrapolate_quantiles,dist_quantiles)
S3method(extrapolate_quantiles,distribution)
S3method(fit,epi_workflow)
S3method(flusight_hub_formatter,canned_epipred)
S3method(flusight_hub_formatter,data.frame)
S3method(format,dist_quantiles)
S3method(is.na,dist_quantiles)
S3method(is.na,distribution)
Expand All @@ -52,6 +54,7 @@ S3method(print,alist)
S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,epi_workflow)
S3method(print,flat_fcast)
S3method(print,flatline)
Expand Down Expand Up @@ -79,6 +82,7 @@ S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
S3method(slather,layer_add_target_date)
S3method(slather,layer_cdc_flatline_quantiles)
S3method(slather,layer_naomit)
S3method(slather,layer_point_from_distn)
S3method(slather,layer_population_scaling)
Expand Down Expand Up @@ -106,6 +110,8 @@ export(arx_classifier)
export(arx_fcast_epi_workflow)
export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
Expand All @@ -122,6 +128,7 @@ export(fit)
export(flatline)
export(flatline_args_list)
export(flatline_forecaster)
export(flusight_hub_formatter)
export(frosting)
export(get_test_data)
export(grab_names)
Expand All @@ -131,6 +138,7 @@ export(is_layer)
export(layer)
export(layer_add_forecast_date)
export(layer_add_target_date)
export(layer_cdc_flatline_quantiles)
export(layer_naomit)
export(layer_point_from_distn)
export(layer_population_scaling)
Expand Down Expand Up @@ -181,6 +189,7 @@ importFrom(rlang,caller_env)
importFrom(rlang,is_empty)
importFrom(rlang,is_null)
importFrom(rlang,quos)
importFrom(smoothqr,smooth_qr)
importFrom(stats,as.formula)
importFrom(stats,family)
importFrom(stats,lm)
Expand Down
228 changes: 228 additions & 0 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
#' Predict the future with the most recent value
#'
#' This is a simple forecasting model for
#' [epiprocess::epi_df] data. It uses the most recent observation as the
#' forecast for any future date, and produces intervals by shuffling the quantiles
#' of the residuals of such a "flatline" forecast and incrementing these
#' forward over all available training data.
#'
#' By default, the predictive intervals are computed separately for each
#' combination of `geo_value` in the `epi_data` argument.
#'
#' This forecaster is meant to produce exactly the CDC Baseline used for
#' [COVID19ForecastHub](https://covid19forecasthub.org).
#'
#' @param epi_data An [`epiprocess::epi_df`]
#' @param outcome A scalar character for the column name we wish to predict.
#' @param args_list A list of additional arguments as created by the
#'   [cdc_baseline_args_list()] constructor function.
#'
#' @return A list with class `c("cdc_baseline_fcast", "canned_epipred")`
#'   containing
#'   * `predictions`: a tibble of point and interval forecasts for all aheads
#'     (unique horizons) for each unique combination of `key_vars`,
#'   * `epi_workflow`: the fitted workflow used to produce the predictions,
#'   * `metadata`: a list holding the `metadata` attribute of `epi_data`
#'     (`training`) and the creation time (`forecast_created`).
#' @export
#'
#' @examples
#' library(dplyr)
#' weekly_deaths <- case_death_rate_subset %>%
#'   select(geo_value, time_value, death_rate) %>%
#'   left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#'   mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#'   select(-pop, -death_rate) %>%
#'   group_by(geo_value) %>%
#'   epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#'   ungroup() %>%
#'   filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' preds <- pivot_quantiles(cdc$predictions, .pred_distn)
#'
#' if (require(ggplot2)) {
#'   forecast_date <- unique(preds$forecast_date)
#'   four_states <- c("ca", "pa", "wa", "ny")
#'   preds %>%
#'     filter(geo_value %in% four_states) %>%
#'     ggplot(aes(target_date)) +
#'     geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
#'     geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
#'     geom_line(aes(y = .pred), color = "orange") +
#'     geom_line(
#'       data = weekly_deaths %>% filter(geo_value %in% four_states),
#'       aes(x = time_value, y = deaths)
#'     ) +
#'     scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#'     labs(x = "Date", y = "Weekly deaths") +
#'     facet_wrap(~geo_value, scales = "free_y") +
#'     theme_bw() +
#'     geom_vline(xintercept = forecast_date)
#' }
cdc_baseline_forecaster <- function(
    epi_data,
    outcome,
    args_list = cdc_baseline_args_list()) {
  validate_forecaster_inputs(epi_data, outcome, "time_value")
  # NOTE(review): `cdc_baseline_args_list()` tags its result with
  # "cdc_baseline_fcast", not "cdc_flat_fcast", so in practice only the
  # "alist" entry of this check matches -- confirm which class name is intended.
  if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) {
    cli_stop("args_list was not created using `cdc_baseline_args_list()`.")
  }
  keys <- epi_keys(epi_data)
  outcome <- rlang::sym(outcome)

  # The outcome is led by one data period (`data_frequency`, in days) and the
  # outcome plus all key columns are given the "predictor" role.
  r <- epi_recipe(epi_data) %>%
    step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
    recipes::update_role(!!outcome, new_role = "predictor") %>%
    recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
    step_training_window(n_recent = args_list$n_training)

  # Default the forecast date to the last observed time value.
  forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)

  latest <- get_test_data(
    epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
    forecast_date
  )

  f <- frosting() %>%
    layer_predict() %>%
    layer_cdc_flatline_quantiles(
      aheads = args_list$aheads,
      quantile_levels = args_list$quantile_levels,
      nsims = args_list$nsims,
      by_key = args_list$quantile_by_key,
      symmetrize = args_list$symmetrize,
      nonneg = args_list$nonneg
    ) %>%
    layer_add_forecast_date(forecast_date = forecast_date) %>%
    layer_unnest(.pred_distn_all)
  if (args_list$nonneg) {
    f <- layer_threshold(f, ".pred")
  }

  eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")

  wf <- epi_workflow(r, eng, f)
  wf <- generics::fit(wf, epi_data)
  # `target_date` is derived here rather than in a frosting layer: each row's
  # `ahead` (presumably supplied by the unnested `.pred_distn_all` -- TODO
  # confirm) is scaled by the data frequency in days.
  preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
    tibble::as_tibble() %>%
    dplyr::select(-time_value) %>%
    dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency)

  structure(
    list(
      predictions = preds,
      epi_workflow = wf,
      metadata = list(
        training = attr(epi_data, "metadata"),
        forecast_created = Sys.time()
      )
    ),
    class = c("cdc_baseline_fcast", "canned_epipred")
  )
}



#' CDC baseline forecaster argument constructor
#'
#' Constructs a list of arguments for [cdc_baseline_forecaster()].
#'
#' @inheritParams arx_args_list
#' @param data_frequency Integer or string. This describes the frequency of the
#'   input `epi_df`. For typical FluSight forecasts, this would be `"1 week"`.
#'   Allowable arguments are integers (taken to mean numbers of days) or a
#'   string like `"7 days"` or `"2 weeks"`. Currently, all other periods
#'   (other than days or weeks) result in an error. The value is stored in the
#'   returned list as an integer number of days.
#' @param aheads Integer vector. Unlike [arx_forecaster()], this doesn't have
#'   any effect on the predicted values.
#'   Predictions are always the most recent observation. This determines the
#'   set of prediction horizons for [layer_cdc_flatline_quantiles()]. It interacts
#'   with the `data_frequency` argument. So, for example, if the data is daily
#'   and you want forecasts for 1:4 days ahead, then you would use `1:4`. However,
#'   if you want one-week predictions, you would set this as `c(7, 14, 21, 28)`.
#'   But if `data_frequency` is `"1 week"`, then you would set it as `1:4`.
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#'   prediction intervals. These are created by computing the quantiles of
#'   training residuals. A `NULL` value will result in point forecasts only.
#' @param nsims Positive integer. The number of draws from the empirical CDF.
#'   These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
#'   in linear interpolation on the X scale. This is achieved with
#'   [stats::quantile()] Type 7 (the default for that function).
#' @param nonneg Logical. Force all predictive intervals be non-negative.
#'   Because non-negativity is forced _before_ propagating forward, this
#'   has slightly different behaviour than would occur if using
#'   [layer_threshold()].
#'
#' @return A list containing updated parameter choices with class
#'   `c("cdc_baseline_fcast", "alist")`.
#' @export
#'
#' @examples
#' cdc_baseline_args_list()
#' cdc_baseline_args_list(symmetrize = FALSE)
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
cdc_baseline_args_list <- function(
    data_frequency = "1 week",
    aheads = 1:4,
    n_training = Inf,
    forecast_date = NULL,
    quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
    nsims = 1e3L,
    symmetrize = TRUE,
    nonneg = TRUE,
    quantile_by_key = "geo_value",
    nafill_buffer = Inf) {
  arg_is_scalar(n_training, nsims, data_frequency)
  # Normalise the period specification to an integer number of days.
  data_frequency <- parse_period(data_frequency)
  arg_is_pos_int(data_frequency)
  arg_is_chr(quantile_by_key, allow_empty = TRUE)
  arg_is_scalar(forecast_date, allow_null = TRUE)
  arg_is_date(forecast_date, allow_null = TRUE)
  arg_is_nonneg_int(aheads, nsims)
  arg_is_lgl(symmetrize, nonneg)
  arg_is_probabilities(quantile_levels, allow_null = TRUE)
  arg_is_pos(n_training)
  # `Inf` is a legal "use everything" sentinel, so the integer check only
  # applies to finite values.
  if (is.finite(n_training)) arg_is_pos_int(n_training)
  # `is.finite(NULL)` is `logical(0)`, which makes a bare `if ()` fail with
  # "argument is of length zero"; guard the NULL case explicitly so that a
  # `NULL` buffer is accepted, as the original `allow_null = TRUE` intended.
  arg_is_scalar(nafill_buffer, allow_null = TRUE)
  if (!is.null(nafill_buffer) && is.finite(nafill_buffer)) {
    arg_is_pos_int(nafill_buffer)
  }

  structure(
    enlist(
      data_frequency,
      aheads,
      n_training,
      forecast_date,
      quantile_levels,
      nsims,
      symmetrize,
      nonneg,
      quantile_by_key,
      nafill_buffer
    ),
    class = c("cdc_baseline_fcast", "alist")
  )
}

#' @export
print.cdc_baseline_fcast <- function(x, ...) {
  # Delegate to the next print method in the class chain, passing along the
  # display name of this forecaster as the `name` argument.
  NextMethod(name = "CDC Baseline", ...)
}

# Convert a period specification into a whole number of days.
#
# `x` is a scalar: either a number (taken as days) or a string such as "7",
# "3 days" or "2 weeks". Only day- and week-based units are understood; the
# unit is recognised from its first three characters ("day" / "wee").
# Returns an integer. Aborts when the string has an unsupported shape or unit,
# or when its count is not numeric; a non-integerish value fails `stopifnot()`.
parse_period <- function(x) {
  arg_is_scalar(x)
  # The only visible caller passes `data_frequency`, so name that argument in
  # the error (the message previously blamed `aheads`).
  bad_period <- function() {
    cli::cli_abort("incompatible timespan in `data_frequency`.")
  }
  if (is.character(x)) {
    parts <- unlist(strsplit(x, " "))
    if (!(length(parts) %in% c(1L, 2L))) bad_period()

    # Reject a non-numeric count up front instead of letting an `NA` (plus a
    # coercion warning) leak through to the caller.
    count <- suppressWarnings(as.numeric(parts[1]))
    if (is.na(count)) bad_period()

    days_per_unit <- 1L
    if (length(parts) == 2L) {
      days_per_unit <- switch(substr(parts[2], 1, 3),
        day = 1L,
        wee = 7L,
        bad_period()
      )
    }
    x <- count * days_per_unit
  }
  stopifnot(rlang::is_integerish(x))
  as.integer(x)
}
5 changes: 5 additions & 0 deletions R/compat-purrr.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ map_chr <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, character(1), ...)
}

# Apply `.f` to every element of `.x` (via the local `map()` shim) and
# combine the resulting list into a single vector with `vctrs::list_unchop()`.
map_vec <- function(.x, .f, ...) {
  vctrs::list_unchop(map(.x, .f, ...))
}

map_dfr <- function(.x, .f, ..., .id = NULL) {
.f <- rlang::as_function(.f, env = rlang::global_env())
res <- map(.x, .f, ...)
Expand Down
6 changes: 6 additions & 0 deletions R/flatline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ flatline_forecaster <- function(
#' Constructs a list of arguments for [flatline_forecaster()].
#'
#' @inheritParams arx_args_list
#' @param ahead Integer. Unlike [arx_forecaster()], this doesn't have any effect
#' on the predicted values. Predictions are always the most recent observation.
#' However, this _does_ impact the residuals stored in the object. Residuals
#' are calculated based on this number to mimic how badly you would have done.
#' So for example, `ahead = 7` will create residuals by comparing values
#' 7 days apart.
#'
#' @return A list containing updated parameter choices with class `flatline_alist`.
#' @export
Expand Down
Loading