Skip to content

Commit

Permalink
Merge pull request #12 from mrc-ide/mrc-4339
Browse files Browse the repository at this point in the history
mrc-4339: Add a state saver object to configure returned years
  • Loading branch information
r-ash authored Jun 29, 2023
2 parents 5e7533b + 98d53b9 commit f8b33dd
Show file tree
Hide file tree
Showing 11 changed files with 313 additions and 143 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

run_base_model <- function(data, projection_parameters, sim_years, hiv_steps_per_year, hiv_age_stratification = "full") {
.Call(`_frogger_run_base_model`, data, projection_parameters, sim_years, hiv_steps_per_year, hiv_age_stratification)
run_base_model <- function(data, projection_parameters, sim_years, hiv_steps_per_year, output_steps, hiv_age_stratification = "full") {
.Call(`_frogger_run_base_model`, data, projection_parameters, sim_years, hiv_steps_per_year, output_steps, hiv_age_stratification)
}

serialize_vector <- function(data, path1, path2) {
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ remotes::install_github("mrc-ide/frogger", upgrade = FALSE)
* Add function for calculating incidence rate per sex https://github.com/mrc-ide/frogger/pull/7#discussion_r1217848792
* Make fit work with coarse ages (at the moment not reading all of the coarse stratified data)
* Remove duplicate reading of `hAG_SPAN_full`, read this as `hiv_age_groups_span` and `age_groups_hiv_span`
* Refactor `OutputState` to take a struct of state-space dimensions instead of unpacking the subset of parameters we
need. See https://github.com/mrc-ide/frogger/pull/12#discussion_r1245170775

## Leapfrog to Frogger glossary

Expand Down
54 changes: 35 additions & 19 deletions inst/fit_model/fit_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,45 @@ Eigen::TensorMap <Eigen::Tensor<T, rank>> tensor_to_tensor_map(Eigen::Tensor <T,
return Eigen::TensorMap < Eigen::Tensor < T, rank >> (d.data(), d.dimensions());
}

void save_output(leapfrog::State<double> &state, std::string &output_path) {
// Save output
void save_output(leapfrog::StateSaver<double> &state_saver, std::string &output_path) {
auto state = state_saver.get_full_state();

std::filesystem::path out_path(output_path);
std::filesystem::path total_population_path = out_path / "total_population";
serialize::serialize_tensor<double, 2>(state.total_population, total_population_path);
std::filesystem::path births_path = out_path / "births";
serialize::serialize_tensor<double, 3>(state.total_population, total_population_path);

// Births is just a double so write it out
std::ofstream dest(births_path);
dest << state.births << std::endl;
dest.close();
std::filesystem::path births_path = out_path / "births";
serialize::serialize_tensor<double, 1>(state.births, births_path);

std::filesystem::path natural_deaths_path = out_path / "natural_deaths";
serialize::serialize_tensor<double, 2>(state.natural_deaths, natural_deaths_path);
serialize::serialize_tensor<double, 3>(state.natural_deaths, natural_deaths_path);

std::filesystem::path hiv_population_path = out_path / "hiv_population";
serialize::serialize_tensor<double, 2>(state.hiv_population, hiv_population_path);
serialize::serialize_tensor<double, 3>(state.hiv_population, hiv_population_path);

std::filesystem::path hiv_natural_deaths_path = out_path / "hiv_natural_deaths";
serialize::serialize_tensor<double, 2>(state.hiv_natural_deaths, hiv_natural_deaths_path);
serialize::serialize_tensor<double, 3>(state.hiv_natural_deaths, hiv_natural_deaths_path);

std::filesystem::path hiv_strat_adult_path = out_path / "hiv_strat_adult";
serialize::serialize_tensor<double, 3>(state.hiv_strat_adult, hiv_strat_adult_path);
serialize::serialize_tensor<double, 4>(state.hiv_strat_adult, hiv_strat_adult_path);

std::filesystem::path art_strat_adult_path = out_path / "art_strat_adult";
serialize::serialize_tensor<double, 4>(state.art_strat_adult, art_strat_adult_path);
serialize::serialize_tensor<double, 5>(state.art_strat_adult, art_strat_adult_path);

std::filesystem::path aids_deaths_no_art_path = out_path / "aids_deaths_no_art";
serialize::serialize_tensor<double, 3>(state.aids_deaths_no_art, aids_deaths_no_art_path);
serialize::serialize_tensor<double, 4>(state.aids_deaths_no_art, aids_deaths_no_art_path);

std::filesystem::path infections_path = out_path / "infections";
serialize::serialize_tensor<double, 2>(state.infections, infections_path);
serialize::serialize_tensor<double, 3>(state.infections, infections_path);

std::filesystem::path aids_deaths_art_path = out_path / "aids_deaths_art";
serialize::serialize_tensor<double, 4>(state.aids_deaths_art, aids_deaths_art_path);
serialize::serialize_tensor<double, 5>(state.aids_deaths_art, aids_deaths_art_path);

std::filesystem::path art_initiation_path = out_path / "art_initiation";
serialize::serialize_tensor<double, 3>(state.art_initiation, art_initiation_path);
serialize::serialize_tensor<double, 4>(state.art_initiation, art_initiation_path);

std::filesystem::path hiv_deaths_path = out_path / "hiv_deaths";
serialize::serialize_tensor<double, 2>(state.hiv_deaths, hiv_deaths_path);
serialize::serialize_tensor<double, 3>(state.hiv_deaths, hiv_deaths_path);
}

int main(int argc, char *argv[]) {
Expand Down Expand Up @@ -229,18 +236,27 @@ int main(int argc, char *argv[]) {
params.age_groups_hiv_15plus);
intermediate.reset();

std::vector<int> save_steps(61);
std::iota(save_steps.begin(), save_steps.end(), 0);
leapfrog::StateSaver<double> state_output(sim_years, save_steps, params.age_groups_pop, params.num_genders,
params.disease_stages, params.age_groups_hiv,
params.treatment_stages);
// Save initial state
state_output.save_state(state_current, 0);

// Each time step is mid-point of the year
for (int step = 1; step <= sim_years; ++step) {
state_next.reset();
leapfrog::run_general_pop_demographic_projection(step, params, state_current, state_next, intermediate);
leapfrog::run_hiv_pop_demographic_projection(step, params, state_current, state_next, intermediate);
leapfrog::run_hiv_model_simulation(step, params, state_current, state_next, intermediate);
state_output.save_state(state_next, step);
std::swap(state_current, state_next);
intermediate.reset();
}
std::cout << "Fit complete" << std::endl;

save_output(state_current, output_abs);
save_output(state_output, output_abs);

return 0;
}
13 changes: 11 additions & 2 deletions inst/include/frogger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "general_demographic_projection.hpp"
#include "hiv_demographic_projection.hpp"
#include "model_simulation.hpp"
#include "state_saver.hpp"

namespace leapfrog {

Expand Down Expand Up @@ -32,7 +33,8 @@ void initialise_model_state(const Parameters<real_type> &pars,
}

template<typename real_type>
State<real_type> run_model(int time_steps, const Parameters<real_type> &pars) {
typename StateSaver<real_type>::OutputState run_model(int time_steps, std::vector<int> save_steps,
const Parameters<real_type> &pars) {
State<real_type> state(pars.age_groups_pop, pars.num_genders,
pars.disease_stages, pars.age_groups_hiv,
pars.treatment_stages);
Expand All @@ -44,16 +46,23 @@ State<real_type> run_model(int time_steps, const Parameters<real_type> &pars) {
pars.age_groups_hiv_15plus);
intermediate.reset();

StateSaver<real_type> state_output(time_steps, save_steps, pars.age_groups_pop, pars.num_genders,
pars.disease_stages, pars.age_groups_hiv,
pars.treatment_stages);
// Save initial state
state_output.save_state(state, 0);

// Each time step is mid-point of the year
for (int step = 1; step <= time_steps; ++step) {
state_next.reset();
run_general_pop_demographic_projection(step, pars, state, state_next, intermediate);
run_hiv_pop_demographic_projection(step, pars, state, state_next, intermediate);
run_hiv_model_simulation(step, pars, state, state_next, intermediate);
state_output.save_state(state_next, step);
std::swap(state, state_next);
intermediate.reset();
}
return state;
return state_output.get_full_state();
}

}
118 changes: 118 additions & 0 deletions inst/include/state_saver.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#pragma once

#include "types.hpp"

namespace leapfrog {

template<typename real_type>
class StateSaver {
public:
struct OutputState {
Tensor3<real_type> total_population;
Tensor3<real_type> natural_deaths;
Tensor3<real_type> hiv_population;
Tensor3<real_type> hiv_natural_deaths;
Tensor4<real_type> hiv_strat_adult;
Tensor5<real_type> art_strat_adult;
Tensor1<real_type> births;
Tensor4<real_type> aids_deaths_no_art;
Tensor3<real_type> infections;
Tensor5<real_type> aids_deaths_art;
Tensor4<real_type> art_initiation;
Tensor3<real_type> hiv_deaths;

OutputState(int age_groups_pop,
int num_genders,
int disease_stages,
int age_groups_hiv,
int treatment_stages,
int no_output_years)
: total_population(age_groups_pop, num_genders, no_output_years),
natural_deaths(age_groups_pop, num_genders, no_output_years),
hiv_population(age_groups_pop, num_genders, no_output_years),
hiv_natural_deaths(age_groups_pop, num_genders, no_output_years),
hiv_strat_adult(disease_stages, age_groups_hiv, num_genders, no_output_years),
art_strat_adult(treatment_stages,
disease_stages,
age_groups_hiv,
num_genders,
no_output_years),
births(no_output_years),
aids_deaths_no_art(disease_stages, age_groups_hiv, num_genders, no_output_years),
infections(age_groups_pop, num_genders, no_output_years),
aids_deaths_art(treatment_stages, disease_stages, age_groups_hiv, num_genders, no_output_years),
art_initiation(disease_stages, age_groups_hiv, num_genders, no_output_years),
hiv_deaths(age_groups_pop, num_genders, no_output_years) {
total_population.setZero();
natural_deaths.setZero();
hiv_population.setZero();
hiv_natural_deaths.setZero();
hiv_strat_adult.setZero();
art_strat_adult.setZero();
births.setZero();
aids_deaths_no_art.setZero();
infections.setZero();
aids_deaths_art.setZero();
art_initiation.setZero();
hiv_deaths.setZero();
}
};

StateSaver(int time_steps,
std::vector<int> save_steps,
int age_groups_pop,
int num_genders,
int disease_stages,
int age_groups_hiv,
int treatment_stages) :
save_steps(save_steps),
full_state(age_groups_pop, num_genders, disease_stages, age_groups_hiv, treatment_stages, save_steps.size()) {
for (int step: save_steps) {
if (step < 0) {
std::stringstream ss;
ss << "Output step must be at least 0, got '" << step << "'." << std::endl;
throw std::runtime_error(ss.str());
}
if (step > time_steps) {
std::stringstream ss;
ss << "Output step can be at most number of time steps run which is '" << time_steps << "', got step '" << step
<< "'." << std::endl;
throw std::runtime_error(ss.str());
}
}
}


void save_state(const State<real_type> &state, int current_year) {
for (size_t i = 0; i < save_steps.size(); ++i) {
if (current_year == save_steps[i]) {
full_state.total_population.chip(i, full_state.total_population.NumDimensions - 1) = state.total_population;
full_state.natural_deaths.chip(i, full_state.natural_deaths.NumDimensions - 1) = state.natural_deaths;
full_state.hiv_population.chip(i, full_state.hiv_population.NumDimensions - 1) = state.hiv_population;
full_state.hiv_natural_deaths.chip(i, full_state.hiv_natural_deaths.NumDimensions - 1) =
state.hiv_natural_deaths;
full_state.hiv_strat_adult.chip(i, full_state.hiv_strat_adult.NumDimensions - 1) = state.hiv_strat_adult;
full_state.art_strat_adult.chip(i, full_state.art_strat_adult.NumDimensions - 1) = state.art_strat_adult;
full_state.births(i) = state.births;
full_state.aids_deaths_no_art.chip(i, full_state.aids_deaths_no_art.NumDimensions - 1) =
state.aids_deaths_no_art;
full_state.infections.chip(i, full_state.infections.NumDimensions - 1) = state.infections;
full_state.aids_deaths_art.chip(i, full_state.aids_deaths_art.NumDimensions - 1) = state.aids_deaths_art;
full_state.art_initiation.chip(i, full_state.art_initiation.NumDimensions - 1) = state.art_initiation;
full_state.hiv_deaths.chip(i, full_state.hiv_deaths.NumDimensions - 1) = state.hiv_deaths;
return;
}
}
}

const OutputState &get_full_state() const {
return full_state;
}


private:
std::vector<int> save_steps;
OutputState full_state;
};

}
3 changes: 3 additions & 0 deletions inst/include/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ using Tensor3 = Eigen::Tensor<real_type, 3>;
template<typename real_type>
using Tensor4 = Eigen::Tensor<real_type, 4>;

template<typename real_type>
using Tensor5 = Eigen::Tensor<real_type, 5>;

template<typename real_type>
struct Parameters {
int num_genders;
Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// run_base_model
Rcpp::List run_base_model(const Rcpp::List data, const Rcpp::List projection_parameters, SEXP sim_years, SEXP hiv_steps_per_year, std::string hiv_age_stratification);
RcppExport SEXP _frogger_run_base_model(SEXP dataSEXP, SEXP projection_parametersSEXP, SEXP sim_yearsSEXP, SEXP hiv_steps_per_yearSEXP, SEXP hiv_age_stratificationSEXP) {
Rcpp::List run_base_model(const Rcpp::List data, const Rcpp::List projection_parameters, SEXP sim_years, SEXP hiv_steps_per_year, Rcpp::NumericVector output_steps, std::string hiv_age_stratification);
RcppExport SEXP _frogger_run_base_model(SEXP dataSEXP, SEXP projection_parametersSEXP, SEXP sim_yearsSEXP, SEXP hiv_steps_per_yearSEXP, SEXP output_stepsSEXP, SEXP hiv_age_stratificationSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const Rcpp::List >::type data(dataSEXP);
Rcpp::traits::input_parameter< const Rcpp::List >::type projection_parameters(projection_parametersSEXP);
Rcpp::traits::input_parameter< SEXP >::type sim_years(sim_yearsSEXP);
Rcpp::traits::input_parameter< SEXP >::type hiv_steps_per_year(hiv_steps_per_yearSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type output_steps(output_stepsSEXP);
Rcpp::traits::input_parameter< std::string >::type hiv_age_stratification(hiv_age_stratificationSEXP);
rcpp_result_gen = Rcpp::wrap(run_base_model(data, projection_parameters, sim_years, hiv_steps_per_year, hiv_age_stratification));
rcpp_result_gen = Rcpp::wrap(run_base_model(data, projection_parameters, sim_years, hiv_steps_per_year, output_steps, hiv_age_stratification));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -53,7 +54,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_frogger_run_base_model", (DL_FUNC) &_frogger_run_base_model, 5},
{"_frogger_run_base_model", (DL_FUNC) &_frogger_run_base_model, 6},
{"_frogger_serialize_vector", (DL_FUNC) &_frogger_serialize_vector, 3},
{"_frogger_deserialize_vector", (DL_FUNC) &_frogger_deserialize_vector, 2},
{NULL, NULL, 0}
Expand Down
Loading

0 comments on commit f8b33dd

Please sign in to comment.