Skip to content

240 quantile pivot #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Oct 5, 2023
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -12,3 +12,4 @@
^musings$
^data-raw$
^vignettes/articles$
^.git-blame-ignore-revs$
4 changes: 2 additions & 2 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
@@ -4,9 +4,9 @@
# Created with usethis + edited to use API key.
on:
push:
branches: [main, master]
branches: [main, master, v0.0.6]
pull_request:
branches: [main, master]
branches: [main, master, v0.0.6]

name: R-CMD-check

1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ Imports:
generics,
glue,
hardhat (>= 1.3.0),
lifecycle,
magrittr,
methods,
quantreg,
10 changes: 9 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -52,6 +52,7 @@ S3method(print,alist)
S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,epi_workflow)
S3method(print,flat_fcast)
S3method(print,flatline)
@@ -79,6 +80,7 @@ S3method(residuals,flatline)
S3method(run_mold,default_epi_recipe_blueprint)
S3method(slather,layer_add_forecast_date)
S3method(slather,layer_add_target_date)
S3method(slather,layer_cdc_flatline_quantiles)
S3method(slather,layer_naomit)
S3method(slather,layer_point_from_distn)
S3method(slather,layer_population_scaling)
@@ -106,6 +108,8 @@ export(arx_classifier)
export(arx_fcast_epi_workflow)
export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
@@ -131,6 +135,7 @@ export(is_layer)
export(layer)
export(layer_add_forecast_date)
export(layer_add_target_date)
export(layer_cdc_flatline_quantiles)
export(layer_naomit)
export(layer_point_from_distn)
export(layer_population_scaling)
@@ -143,7 +148,8 @@ export(layer_unnest)
export(nested_quantiles)
export(new_default_epi_recipe_blueprint)
export(new_epi_recipe_blueprint)
export(pivot_quantiles)
export(pivot_quantiles_longer)
export(pivot_quantiles_wider)
export(prep)
export(quantile_reg)
export(remove_frosting)
@@ -167,6 +173,7 @@ importFrom(generics,augment)
importFrom(generics,fit)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(lifecycle,deprecated)
importFrom(magrittr,"%>%")
importFrom(methods,is)
importFrom(quantreg,rq)
@@ -181,6 +188,7 @@ importFrom(rlang,caller_env)
importFrom(rlang,is_empty)
importFrom(rlang,is_null)
importFrom(rlang,quos)
importFrom(smoothqr,smooth_qr)
importFrom(stats,as.formula)
importFrom(stats,family)
importFrom(stats,lm)
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -7,7 +7,8 @@
* canned forecasters get a class
* fixed quantile bug in `flatline_forecaster()`
* add functionality to output the unfit workflow from the canned forecasters
* add `pivot_quantiles()` for easier plotting
* add `pivot_quantiles_wider()` for easier plotting
* add complement `pivot_quantiles_longer()`


# epipredict 0.0.4
228 changes: 228 additions & 0 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
#' Predict the future with the most recent value
#'
#' This is a simple forecasting model for
#' [epiprocess::epi_df] data. It uses the most recent observation as the
#' forecast for any future date, and produces intervals by shuffling the quantiles
#' of the residuals of such a "flatline" forecast and incrementing these
#' forward over all available training data.
#'
#' By default, the predictive intervals are computed separately for each
#' combination of `geo_value` in the `epi_data` argument.
#'
#' This forecaster is meant to produce exactly the CDC Baseline used for
#' [COVID19ForecastHub](https://covid19forecasthub.org)
#'
#' @param epi_data An [epiprocess::epi_df]
#' @param outcome A scalar character for the column name we wish to predict.
#' @param args_list A list of additional arguments as created by the
#' [cdc_baseline_args_list()] constructor function.
#'
#' @return A data frame of point and interval forecasts at for all
#' aheads (unique horizons) for each unique combination of `key_vars`.
#' @export
#'
#' @examples
#' library(dplyr)
#' weekly_deaths <- case_death_rate_subset %>%
#' select(geo_value, time_value, death_rate) %>%
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#' mutate(deaths = pmax(death_rate / 1e5 * pop, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' preds <- pivot_quantiles_wider(cdc$predictions, .pred_distn)
#'
#' if (require(ggplot2)) {
#' forecast_date <- unique(preds$forecast_date)
#' four_states <- c("ca", "pa", "wa", "ny")
#' preds %>%
#' filter(geo_value %in% four_states) %>%
#' ggplot(aes(target_date)) +
#' geom_ribbon(aes(ymin = `0.1`, ymax = `0.9`), fill = blues9[3]) +
#' geom_ribbon(aes(ymin = `0.25`, ymax = `0.75`), fill = blues9[6]) +
#' geom_line(aes(y = .pred), color = "orange") +
#' geom_line(
#' data = weekly_deaths %>% filter(geo_value %in% four_states),
#' aes(x = time_value, y = deaths)
#' ) +
#' scale_x_date(limits = c(forecast_date - 90, forecast_date + 30)) +
#' labs(x = "Date", y = "Weekly deaths") +
#' facet_wrap(~geo_value, scales = "free_y") +
#' theme_bw() +
#' geom_vline(xintercept = forecast_date)
#' }
cdc_baseline_forecaster <- function(
epi_data,
outcome,
args_list = cdc_baseline_args_list()) {
validate_forecaster_inputs(epi_data, outcome, "time_value")
if (!inherits(args_list, c("cdc_flat_fcast", "alist"))) {
cli_stop("args_list was not created using `cdc_baseline_args_list().")
}
keys <- epi_keys(epi_data)
ek <- kill_time_value(keys)
outcome <- rlang::sym(outcome)


r <- epi_recipe(epi_data) %>%
step_epi_ahead(!!outcome, ahead = args_list$data_frequency, skip = TRUE) %>%
recipes::update_role(!!outcome, new_role = "predictor") %>%
recipes::add_role(tidyselect::all_of(keys), new_role = "predictor") %>%
step_training_window(n_recent = args_list$n_training)

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
# target_date <- args_list$target_date %||% forecast_date + args_list$ahead


latest <- get_test_data(
epi_recipe(epi_data), epi_data, TRUE, args_list$nafill_buffer,
forecast_date
)

f <- frosting() %>%
layer_predict() %>%
layer_cdc_flatline_quantiles(
aheads = args_list$aheads,
quantile_levels = args_list$quantile_levels,
nsims = args_list$nsims,
by_key = args_list$quantile_by_key,
symmetrize = args_list$symmetrize,
nonneg = args_list$nonneg
) %>%
layer_add_forecast_date(forecast_date = forecast_date) %>%
layer_unnest(.pred_distn_all)
# layer_add_target_date(target_date = target_date)
if (args_list$nonneg) f <- layer_threshold(f, ".pred")

eng <- parsnip::linear_reg() %>% parsnip::set_engine("flatline")

wf <- epi_workflow(r, eng, f)
wf <- generics::fit(wf, epi_data)
preds <- suppressWarnings(predict(wf, new_data = latest)) %>%
tibble::as_tibble() %>%
dplyr::select(-time_value) %>%
dplyr::mutate(target_date = forecast_date + ahead * args_list$data_frequency)

structure(
list(
predictions = preds,
epi_workflow = wf,
metadata = list(
training = attr(epi_data, "metadata"),
forecast_created = Sys.time()
)
),
class = c("cdc_baseline_fcast", "canned_epipred")
)
}



#' CDC baseline forecaster argument constructor
#'
#' Constructs a list of arguments for [cdc_baseline_forecaster()].
#'
#' @inheritParams arx_args_list
#' @param data_frequency Integer or string. This describes the frequency of the
#' input `epi_df`. For typical FluSight forecasts, this would be `"1 week"`.
#' Allowable arguments are integers (taken to mean numbers of days) or a
#' string like `"7 days"` or `"2 weeks"`. Currently, all other periods
#' (other than days or weeks) result in an error.
#' @param aheads Integer vector. Unlike [arx_forecaster()], this doesn't have
#' any effect on the predicted values.
#' Predictions are always the most recent observation. This determines the
#' set of prediction horizons for [layer_cdc_flatline_quantiles()]`. It interacts
#' with the `data_frequency` argument. So, for example, if the data is daily
#' and you want forecasts for 1:4 days ahead, then you would use `1:4`. However,
#' if you want one-week predictions, you would set this as `c(7, 14, 21, 28)`.
#' But if `data_frequency` is `"1 week"`, then you would set it as `1:4`.
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#' prediction intervals. These are created by computing the quantiles of
#' training residuals. A `NULL` value will result in point forecasts only.
#' @param nsims Positive integer. The number of draws from the empirical CDF.
#' These samples are spaced evenly on the (0, 1) scale, F_X(x) resulting
#' in linear interpolation on the X scale. This is achieved with
#' [stats::quantile()] Type 7 (the default for that function).
#' @param nonneg Logical. Force all predictive intervals be non-negative.
#' Because non-negativity is forced _before_ propagating forward, this
#' has slightly different behaviour than would occur if using
#' [layer_threshold()].
#'
#' @return A list containing updated parameter choices with class `cdc_flat_fcast`.
#' @export
#'
#' @examples
#' cdc_baseline_args_list()
#' cdc_baseline_args_list(symmetrize = FALSE)
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
cdc_baseline_args_list <- function(
data_frequency = "1 week",
aheads = 1:4,
n_training = Inf,
forecast_date = NULL,
quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
nsims = 1e3L,
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = "geo_value",
nafill_buffer = Inf) {
arg_is_scalar(n_training, nsims, data_frequency)
data_frequency <- parse_period(data_frequency)
arg_is_pos_int(data_frequency)
arg_is_chr(quantile_by_key, allow_empty = TRUE)
arg_is_scalar(forecast_date, allow_null = TRUE)
arg_is_date(forecast_date, allow_null = TRUE)
arg_is_nonneg_int(aheads, nsims)
arg_is_lgl(symmetrize, nonneg)
arg_is_probabilities(quantile_levels, allow_null = TRUE)
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)

structure(
enlist(
data_frequency,
aheads,
n_training,
forecast_date,
quantile_levels,
nsims,
symmetrize,
nonneg,
quantile_by_key,
nafill_buffer
),
class = c("cdc_baseline_fcast", "alist")
)
}

#' @export
print.cdc_baseline_fcast <- function(x, ...) {
name <- "CDC Baseline"
NextMethod(name = name, ...)
}

parse_period <- function(x) {
arg_is_scalar(x)
if (is.character(x)) {
x <- unlist(strsplit(x, " "))
if (length(x) == 1L) x <- as.numeric(x)
if (length(x) == 2L) {
mult <- substr(x[2], 1, 3)
mult <- switch(
mult,
day = 1L,
wee = 7L,
cli::cli_abort("incompatible timespan in `aheads`.")
)
x <- as.numeric(x[1]) * mult
}
if (length(x) > 2L) cli::cli_abort("incompatible timespan in `aheads`.")
}
stopifnot(rlang::is_integerish(x))
as.integer(x)
}
5 changes: 5 additions & 0 deletions R/compat-purrr.R
Original file line number Diff line number Diff line change
@@ -32,6 +32,11 @@ map_chr <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, character(1), ...)
}

map_vec <- function(.x, .f, ...) {
out <- map(.x, .f, ...)
vctrs::list_unchop(out)
}

map_dfr <- function(.x, .f, ..., .id = NULL) {
.f <- rlang::as_function(.f, env = rlang::global_env())
res <- map(.x, .f, ...)
Loading