Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

estimate r from initial R #923

Merged
merged 34 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
c7d48a1
estimate r from initial R
sbfnk Jan 10, 2025
07b2b51
remove unneeded argument
sbfnk Jan 10, 2025
06bd809
update tests
sbfnk Jan 10, 2025
5239e63
update naming to be more consistent
sbfnk Jan 10, 2025
a77ce56
don't use sundials solver
sbfnk Jan 10, 2025
b4f18b4
update snapshot
sbfnk Jan 10, 2025
8c9750f
estimate initial infections within model
sbfnk Jan 13, 2025
72340b9
update simulation snapshot
sbfnk Jan 13, 2025
33692c6
another snapshot update
sbfnk Jan 13, 2025
6931751
doc update
sbfnk Jan 13, 2025
1f8289d
update snapshot again
sbfnk Jan 13, 2025
998ef5c
try manual approach
sbfnk Jan 13, 2025
2a006a4
relax prior
sbfnk Jan 13, 2025
618f8bc
scale with early cases
sbfnk Jan 13, 2025
487f260
update snapshots
sbfnk Jan 13, 2025
8986fb1
fabs -> abs
sbfnk Jan 13, 2025
60d7cd5
pass cases
sbfnk Jan 13, 2025
9ed36e3
ensure mean init is >=1
sbfnk Jan 14, 2025
9beca9a
move to transformed data
sbfnk Jan 20, 2025
20e4f6f
add source
sbfnk Jan 20, 2025
27fa43d
update news item
sbfnk Jan 20, 2025
d3b63f3
fix simulations
sbfnk Jan 20, 2025
31ebdc9
update tests
sbfnk Jan 20, 2025
11534c7
update snapshots
sbfnk Jan 20, 2025
c62be58
temporarily remove additional repos
sbfnk Jan 21, 2025
9dd97d5
Revert "temporarily remove additional repos"
sbfnk Jan 21, 2025
d70e4e5
touchstone: don't upgrade
sbfnk Jan 21, 2025
7f48f9e
stabilise initial guess
sbfnk Jan 22, 2025
684f139
Revert "touchstone: don't upgrade"
sbfnk Jan 22, 2025
75628ff
update sim snapshot
sbfnk Jan 22, 2025
0449171
try max instead of correction
sbfnk Jan 22, 2025
dba0c2e
remove/rename
sbfnk Jan 22, 2025
6a36058
version 2: testing for speed
sbfnk Jan 23, 2025
35740e5
Revert "version 2: testing for speed"
sbfnk Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and #890 and reviewed by @seabbs.
- The Gaussian Process lengthscale is now scaled internally by half the length of the time series. By @sbfnk in #890 and reviewed by #seabbs.
- A bug was fixed where `plot.dist_spec()` wasn't throwing an informative error due to an incomplete check for the max of the specified delay. By @jamesmbaazam in #858 and reviewed by @.
- Updated the early dynamics calculation to use the full linear model if available. Also changed the prior for initial infections to be approximately Poisson and the initial growth rate to the point estimate of the initial growth rate scaled linearly by the estimated initial infections term. By @sbfnk in #903 and reviewed by @seabbs and @SamuelBrand1
- Updated the early dynamics calculation to estimate growth from the initial reproduction number instead of a separate linear model. Also changed the prior calculation for initial infections to be a scaling factor of early case numbers adjusted by the growth estimate, instead a true number of initial infections. By @sbfnk in #923 (with initial exploration in #903) and reviewed by @seabbs and @SamuelBrand1.

## Package changes

Expand Down
57 changes: 1 addition & 56 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -440,54 +440,6 @@ create_forecast_data <- function(forecast = forecast_opts(), data) {
return(data)
}

#' Calculate prior infections and fit early growth
#'
#' @description Calculates the prior infections and growth rate based on the
#' first week's data.
#'
#' @param cases Numeric vector; the case counts from the input data.
#' @inheritParams create_stan_data
#' @return A list containing `initial_infections_estimate` and
#' `initial_growth_estimate`.
#' @keywords internal
estimate_early_dynamics <- function(cases, seeding_time) {
initial_period <- data.table::data.table(
confirm = cases[seq_len(min(7, seeding_time, length(cases)))],
t = seq_len(min(7, seeding_time, length(cases))) - 1
)[!is.na(confirm)]

# Calculate initial infections and growth estimate
if (seeding_time > 1 && nrow(initial_period) > 1) {
safe_lm <- purrr::safely(stats::lm)
log_linear_estimate <- safe_lm(log(confirm) ~ t, data = initial_period)[[1]]
initial_infections_estimate <- ifelse(
is.null(log_linear_estimate), 0, log_linear_estimate$coefficients[1]
)
initial_growth_estimate <- ifelse(
is.null(log_linear_estimate), 0, log_linear_estimate$coefficients[2]
)
} else {
initial_infections_estimate <- 0
initial_growth_estimate <- 0
}

# Calculate prior infections
if (initial_infections_estimate == 0) {
initial_infections_estimate <- log(
mean(initial_period$confirm, na.rm = TRUE)
)
if (is.na(initial_infections_estimate) ||
is.null(initial_infections_estimate)) {
initial_infections_estimate <- 0
}
}

return(list(
initial_infections_estimate = initial_infections_estimate,
initial_growth_estimate = initial_growth_estimate
))
}

#' Create Stan Data Required for estimate_infections
#'
#' @description`r lifecycle::badge("stable")`
Expand Down Expand Up @@ -553,11 +505,6 @@ create_stan_data <- function(data, seeding_time, rt, gp, obs, backcalc,
delay = stan_data$seeding_time, horizon = stan_data$horizon
)
)
# calculate prior infections and fit early growth
stan_data <- c(
stan_data,
estimate_early_dynamics(confirmed_cases, seeding_time)
)
# backcalculation settings
stan_data <- c(stan_data, create_backcalc_data(backcalc))
# gaussian process data
Expand Down Expand Up @@ -639,9 +586,7 @@ create_initial_conditions <- function(data) {
out$eta <- array(numeric(0))
}
if (data$estimate_r == 1) {
out$initial_infections <- array(
rnorm(1, data$initial_infections_estimate, 0.2)
)
out$initial_infections <- array(rnorm(1))
}

if (data$bp_n > 0) {
Expand Down
19 changes: 5 additions & 14 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,14 @@ simulate_infections <- function(estimates, R, initial_infections,
if (missing(seeding_time)) {
seeding_time <- sum(max(generation_time))
}
if (seeding_time > 1) {
## estimate initial growth from initial reproduction number if seeding time
## is greater than 1
initial_growth <- (R$R[1] - 1) / mean(generation_time)
## adjust initial infections for initial exponential growth
log_initial_infections <- log(initial_infections) -
(seeding_time - 1) * initial_growth
} else {
initial_growth <- numeric(0)
log_initial_infections <- log(initial_infections)
}

data <- list(
n = 1,
t = nrow(R) + seeding_time,
seeding_time = seeding_time,
future_time = 0,
initial_infections = array(log_initial_infections, dim = c(1, 1)),
initial_growth = array(initial_growth, dim = c(1, length(initial_growth))),
initial_infections = array(log(initial_infections), dim = c(1, 1)),
initial_as_scale = 0,
R = array(R$R, dim = c(1, nrow(R))),
pop = pop
)
Expand Down Expand Up @@ -433,7 +422,9 @@ forecast_infections <- function(estimates,
draws <- map(draws, ~ as.matrix(.[nstart:nend, ]))

## prepare data for stan command
data <- c(list(n = dim(draws$R)[1]), draws, estimates$args)
data <- c(
list(n = dim(draws$R)[1], initial_as_scale = 1), draws, estimates$args
)

## allocate empty parameters
data <- allocate_empty(
Expand Down
2 changes: 0 additions & 2 deletions inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
int estimate_r; // should the reproduction no be estimated (1 = yes)
real initial_infections_estimate; // point estimate of initial infections
real initial_growth_estimate; // point estimate of initial growth rate
int bp_n; // no of breakpoints (0 = no breakpoints)
array[t - seeding_time] int breakpoints; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
array[n, 1] real initial_infections; // initial logged infections
array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth
int initial_as_scale; // whether to interpret initial infections as scaling

matrix[n, t - seeding_time] R; // reproduction number
int pop; // susceptible population
Expand Down
19 changes: 10 additions & 9 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ transformed data {
delay_types_groups, delay_max, delay_np_pmf_groups
);
}

// initial infections scaling (on the log scale)
real initial_infections_guess = fmax(
0,
log(mean(head(cases, num_elements(cases) > 7 ? 7 : num_elements(cases))))
);
}

parameters {
Expand All @@ -60,12 +66,6 @@ transformed parameters {
vector[ot_h] reports; // estimated reported cases
vector[ot] obs_reports; // observed estimated reported cases
vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf;
array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate

if (num_elements(initial_growth) > 0) {
initial_growth[1] = initial_growth_estimate +
initial_infections_estimate - initial_infections[1];
}

// GP in noise - spectral densities
profile("update gp") {
Expand Down Expand Up @@ -108,8 +108,8 @@ transformed parameters {
params
);
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time, obs_scale, frac_obs
R, seeding_time, gt_rev_pmf, initial_infections, pop,
future_time, obs_scale, frac_obs, 1
);
}
} else {
Expand Down Expand Up @@ -210,7 +210,8 @@ model {
// priors on Rt
profile("rt lp") {
rt_lp(
initial_infections, bp_effects, bp_sd, bp_n, initial_infections_estimate
initial_infections, bp_effects, bp_sd, bp_n,
cases, initial_infections_guess
);
}
}
Expand Down
25 changes: 15 additions & 10 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,33 @@ real update_infectiousness(vector infections, vector gt_rev_pmf,
);
return(new_inf);
}

// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
array[] real initial_infections, array[] real initial_growth,
int pop, int ht, int obs_scale, real frac_obs) {
vector generate_infections(vector R, int uot, vector gt_rev_pmf,
array[] real initial_infections, int pop, int ht,
int obs_scale, real frac_obs, int initial_as_scale) {
// time indices and storage
int ot = num_elements(oR);
int ot = num_elements(R);
int nht = ot - ht;
int t = ot + uot;
vector[ot] R = oR;
real exp_adj_Rt;
vector[t] infections = rep_vector(0, t);
vector[ot] cum_infections;
vector[ot] infectiousness;
real growth = R_to_r(R[1], gt_rev_pmf, 1e-3);
// Initialise infections using daily growth
infections[1] = exp(initial_infections[1]);
if (obs_scale) {
infections[1] = infections[1] / frac_obs;
if (initial_as_scale) {
infections[1] = exp(initial_infections[1] - growth * uot);
if (obs_scale) {
infections[1] = infections[1] / frac_obs;
}
} else {
infections[1] = exp(initial_infections[1]);
}
if (uot > 1) {
real growth = exp(initial_growth[1]);
real exp_growth = exp(growth);
for (s in 2:uot) {
infections[s] = infections[s - 1] * growth;
infections[s] = infections[s - 1] * exp_growth;
}
}
// calculate cumulative infections
Expand Down
58 changes: 51 additions & 7 deletions inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,64 @@ vector update_Rt(int t, real R0, vector noise, array[] int bps,
/**
* Calculate the log-probability of the reproduction number (Rt) priors
*
* @param initial_infections Array of initial infection values
* @param initial_infections_scale Array of initial infection values
* @param bp_effects Vector of breakpoint effects
* @param bp_sd Array of breakpoint standard deviations
* @param bp_n Number of breakpoints
* @param prior_infections Prior mean for initial infections
*/
void rt_lp(array[] real initial_infections, vector bp_effects,
array[] real bp_sd, int bp_n, real prior_infections) {
void rt_lp(array[] real initial_infections_scale, vector bp_effects,
array[] real bp_sd, int bp_n, array[] int cases,
real initial_infections_guess) {
//breakpoint effects on Rt
if (bp_n > 0) {
bp_sd[1] ~ normal(0, 0.1) T[0,];
bp_effects ~ normal(0, bp_sd[1]);
}
// initial infections
initial_infections ~ normal(prior_infections, sqrt(prior_infections));

initial_infections_scale ~ normal(initial_infections_guess, 2);
}

/**
* Helper function for calculating r from R using Newton's method
*
* Code is based on Julia code from
* https://github.com/CDCgov/Rt-without-renewal/blob/d6344cc6e451e3e6c4188e4984247f890ae60795/EpiAware/test/predictive_checking/fast_approx_for_r.jl
* under Apache license 2.0.
*
* @param R Reproduction number
* @param r growth rate
* @param pmf generation time probability mass function (first index: 0)
*/
real R_to_r_newton_step(real R, real r, vector pmf) {
int len = num_elements(pmf);
vector[len] zero_series = linspaced_vector(len, 0, len - 1);
vector[len] exp_r = exp(-r * zero_series);
real ret = (R * dot_product(pmf, exp_r) - 1) /
(- R * dot_product(pmf .* zero_series, exp_r));
return(ret);
}

/**
* Estimate the growth rate r from reproduction number R. Used in the model to
* estimate the initial growth rate using Newton's method.
*
* Code is based on Julia code from
* https://github.com/CDCgov/Rt-without-renewal/blob/d6344cc6e451e3e6c4188e4984247f890ae60795/EpiAware/test/predictive_checking/fast_approx_for_r.jl
* under Apache license 2.0.
*
* @param R reproduction number
* @param gt_rev_pmf reverse probability mass function of the generation time
* @param abs_tol absolute tolerance of the solver
*/
real R_to_r(real R, vector gt_rev_pmf, real abs_tol) {
int gt_len = num_elements(gt_rev_pmf);
vector[gt_len] gt_pmf = reverse(gt_rev_pmf);
real mean_gt = dot_product(gt_pmf, linspaced_vector(gt_len, 0, gt_len - 1));
real r = (R - 1) / (R * mean_gt + 1);
real step = abs_tol + 1;
while (abs(step) > abs_tol) {
step = R_to_r_newton_step(R, r, gt_pmf);
r -= step;
}

return(r);
}
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ generated quantities {

infections[i] = to_row_vector(generate_infections(
to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
initial_growth[i], pop, future_time, obs_scale, frac_obs[i]
pop, future_time, obs_scale, frac_obs[i], initial_as_scale
));

if (delay_id) {
Expand Down
23 changes: 0 additions & 23 deletions man/estimate_early_dynamics.Rd

This file was deleted.

Loading
Loading