Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mrc-4339: Add a state saver object to configure returned years #12

Merged
merged 7 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

run_base_model <- function(data, projection_parameters, sim_years, hiv_steps_per_year, hiv_age_stratification = "full") {
.Call(`_frogger_run_base_model`, data, projection_parameters, sim_years, hiv_steps_per_year, hiv_age_stratification)
run_base_model <- function(data, projection_parameters, sim_years, hiv_steps_per_year, output_steps, hiv_age_stratification = "full") {
.Call(`_frogger_run_base_model`, data, projection_parameters, sim_years, hiv_steps_per_year, output_steps, hiv_age_stratification)
}

serialize_vector <- function(data, path1, path2) {
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ remotes::install_github("mrc-ide/frogger", upgrade = FALSE)
* Add function for calculating incidence rate per sex https://github.com/mrc-ide/frogger/pull/7#discussion_r1217848792
* Make fit work with coarse ages (at the moment not reading all of the coarse stratified data)
* Remove duplicate reading of `hAG_SPAN_full`, read this as `hiv_age_groups_span` and `age_groups_hiv_span`
* Refactor `OutputState` to take a struct of state-space dimensions instead of unpacking the subset of parameters we
need. See https://github.com/mrc-ide/frogger/pull/12#discussion_r1245170775

## Leapfrog to Frogger glossary

Expand Down
54 changes: 35 additions & 19 deletions inst/fit_model/fit_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,45 @@ Eigen::TensorMap <Eigen::Tensor<T, rank>> tensor_to_tensor_map(Eigen::Tensor <T,
return Eigen::TensorMap < Eigen::Tensor < T, rank >> (d.data(), d.dimensions());
}

void save_output(leapfrog::State<double> &state, std::string &output_path) {
// Save output
void save_output(leapfrog::StateSaver<double> &state_saver, std::string &output_path) {
auto state = state_saver.get_full_state();

std::filesystem::path out_path(output_path);
std::filesystem::path total_population_path = out_path / "total_population";
serialize::serialize_tensor<double, 2>(state.total_population, total_population_path);
std::filesystem::path births_path = out_path / "births";
serialize::serialize_tensor<double, 3>(state.total_population, total_population_path);

// Births is just a double so write it out
std::ofstream dest(births_path);
dest << state.births << std::endl;
dest.close();
std::filesystem::path births_path = out_path / "births";
serialize::serialize_tensor<double, 1>(state.births, births_path);

std::filesystem::path natural_deaths_path = out_path / "natural_deaths";
serialize::serialize_tensor<double, 2>(state.natural_deaths, natural_deaths_path);
serialize::serialize_tensor<double, 3>(state.natural_deaths, natural_deaths_path);

std::filesystem::path hiv_population_path = out_path / "hiv_population";
serialize::serialize_tensor<double, 2>(state.hiv_population, hiv_population_path);
serialize::serialize_tensor<double, 3>(state.hiv_population, hiv_population_path);

std::filesystem::path hiv_natural_deaths_path = out_path / "hiv_natural_deaths";
serialize::serialize_tensor<double, 2>(state.hiv_natural_deaths, hiv_natural_deaths_path);
serialize::serialize_tensor<double, 3>(state.hiv_natural_deaths, hiv_natural_deaths_path);

std::filesystem::path hiv_strat_adult_path = out_path / "hiv_strat_adult";
serialize::serialize_tensor<double, 3>(state.hiv_strat_adult, hiv_strat_adult_path);
serialize::serialize_tensor<double, 4>(state.hiv_strat_adult, hiv_strat_adult_path);

std::filesystem::path art_strat_adult_path = out_path / "art_strat_adult";
serialize::serialize_tensor<double, 4>(state.art_strat_adult, art_strat_adult_path);
serialize::serialize_tensor<double, 5>(state.art_strat_adult, art_strat_adult_path);

std::filesystem::path aids_deaths_no_art_path = out_path / "aids_deaths_no_art";
serialize::serialize_tensor<double, 3>(state.aids_deaths_no_art, aids_deaths_no_art_path);
serialize::serialize_tensor<double, 4>(state.aids_deaths_no_art, aids_deaths_no_art_path);

std::filesystem::path infections_path = out_path / "infections";
serialize::serialize_tensor<double, 2>(state.infections, infections_path);
serialize::serialize_tensor<double, 3>(state.infections, infections_path);

std::filesystem::path aids_deaths_art_path = out_path / "aids_deaths_art";
serialize::serialize_tensor<double, 4>(state.aids_deaths_art, aids_deaths_art_path);
serialize::serialize_tensor<double, 5>(state.aids_deaths_art, aids_deaths_art_path);

std::filesystem::path art_initiation_path = out_path / "art_initiation";
serialize::serialize_tensor<double, 3>(state.art_initiation, art_initiation_path);
serialize::serialize_tensor<double, 4>(state.art_initiation, art_initiation_path);

std::filesystem::path hiv_deaths_path = out_path / "hiv_deaths";
serialize::serialize_tensor<double, 2>(state.hiv_deaths, hiv_deaths_path);
serialize::serialize_tensor<double, 3>(state.hiv_deaths, hiv_deaths_path);
}

int main(int argc, char *argv[]) {
Expand Down Expand Up @@ -229,18 +236,27 @@ int main(int argc, char *argv[]) {
params.age_groups_hiv_15plus);
intermediate.reset();

std::vector<int> save_steps(61);
std::iota(save_steps.begin(), save_steps.end(), 0);
leapfrog::StateSaver<double> state_output(sim_years, save_steps, params.age_groups_pop, params.num_genders,
params.disease_stages, params.age_groups_hiv,
params.treatment_stages);
// Save initial state
state_output.save_state(state_current, 0);

// Each time step is mid-point of the year
for (int step = 1; step <= sim_years; ++step) {
state_next.reset();
leapfrog::run_general_pop_demographic_projection(step, params, state_current, state_next, intermediate);
leapfrog::run_hiv_pop_demographic_projection(step, params, state_current, state_next, intermediate);
leapfrog::run_hiv_model_simulation(step, params, state_current, state_next, intermediate);
state_output.save_state(state_next, step);
std::swap(state_current, state_next);
intermediate.reset();
}
std::cout << "Fit complete" << std::endl;

save_output(state_current, output_abs);
save_output(state_output, output_abs);

return 0;
}
13 changes: 11 additions & 2 deletions inst/include/frogger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "general_demographic_projection.hpp"
#include "hiv_demographic_projection.hpp"
#include "model_simulation.hpp"
#include "state_saver.hpp"

namespace leapfrog {

Expand Down Expand Up @@ -32,7 +33,8 @@ void initialise_model_state(const Parameters<real_type> &pars,
}

template<typename real_type>
State<real_type> run_model(int time_steps, const Parameters<real_type> &pars) {
typename StateSaver<real_type>::OutputState run_model(int time_steps, std::vector<int> save_steps,
const Parameters<real_type> &pars) {
State<real_type> state(pars.age_groups_pop, pars.num_genders,
pars.disease_stages, pars.age_groups_hiv,
pars.treatment_stages);
Expand All @@ -44,16 +46,23 @@ State<real_type> run_model(int time_steps, const Parameters<real_type> &pars) {
pars.age_groups_hiv_15plus);
intermediate.reset();

StateSaver<real_type> state_output(time_steps, save_steps, pars.age_groups_pop, pars.num_genders,
pars.disease_stages, pars.age_groups_hiv,
pars.treatment_stages);
// Save initial state
state_output.save_state(state, 0);

// Each time step is mid-point of the year
for (int step = 1; step <= time_steps; ++step) {
state_next.reset();
run_general_pop_demographic_projection(step, pars, state, state_next, intermediate);
run_hiv_pop_demographic_projection(step, pars, state, state_next, intermediate);
run_hiv_model_simulation(step, pars, state, state_next, intermediate);
state_output.save_state(state_next, step);
std::swap(state, state_next);
intermediate.reset();
}
return state;
return state_output.get_full_state();
}

}
118 changes: 118 additions & 0 deletions inst/include/state_saver.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#pragma once

#include "types.hpp"

namespace leapfrog {

template<typename real_type>
class StateSaver {
public:
struct OutputState {
Tensor3<real_type> total_population;
Tensor3<real_type> natural_deaths;
Tensor3<real_type> hiv_population;
Tensor3<real_type> hiv_natural_deaths;
Tensor4<real_type> hiv_strat_adult;
Tensor5<real_type> art_strat_adult;
Tensor1<real_type> births;
Tensor4<real_type> aids_deaths_no_art;
Tensor3<real_type> infections;
Tensor5<real_type> aids_deaths_art;
Tensor4<real_type> art_initiation;
Tensor3<real_type> hiv_deaths;

OutputState(int age_groups_pop,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

most these values, defining the dimensions, are probably subsets of the model parameters? do you want to pass in the parameters object here, or create some sub-struct within the object so that this call signature does not grow unmanageably, and also that a state saver and its underlying model run are always compatible?

Do we imagine that all of these are always wanted?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes they are, so these are all the dimensions of our state-space. We've started thinking about splitting this out in https://github.com/mrc-ide/frogger/pull/11/files#diff-b335630551682c19a781afebcf4d07bf978fb1f8ac04c6bf87428ed5106870f5R51 as part of that PR I will update the names and also create some sub-struct for the parameters. If it looks ok I would wait until that PR is pinned down and use whatever struct comes from that. I'll add a TODO note here into the README

int num_genders,
int disease_stages,
int age_groups_hiv,
int treatment_stages,
int no_output_years)
: total_population(age_groups_pop, num_genders, no_output_years),
natural_deaths(age_groups_pop, num_genders, no_output_years),
hiv_population(age_groups_pop, num_genders, no_output_years),
hiv_natural_deaths(age_groups_pop, num_genders, no_output_years),
hiv_strat_adult(disease_stages, age_groups_hiv, num_genders, no_output_years),
art_strat_adult(treatment_stages,
disease_stages,
age_groups_hiv,
num_genders,
no_output_years),
births(no_output_years),
aids_deaths_no_art(disease_stages, age_groups_hiv, num_genders, no_output_years),
infections(age_groups_pop, num_genders, no_output_years),
aids_deaths_art(treatment_stages, disease_stages, age_groups_hiv, num_genders, no_output_years),
art_initiation(disease_stages, age_groups_hiv, num_genders, no_output_years),
hiv_deaths(age_groups_pop, num_genders, no_output_years) {
total_population.setZero();
natural_deaths.setZero();
hiv_population.setZero();
hiv_natural_deaths.setZero();
hiv_strat_adult.setZero();
art_strat_adult.setZero();
births.setZero();
aids_deaths_no_art.setZero();
infections.setZero();
aids_deaths_art.setZero();
art_initiation.setZero();
hiv_deaths.setZero();
}
};

StateSaver(int time_steps,
std::vector<int> save_steps,
int age_groups_pop,
int num_genders,
int disease_stages,
int age_groups_hiv,
int treatment_stages) :
save_steps(save_steps),
full_state(age_groups_pop, num_genders, disease_stages, age_groups_hiv, treatment_stages, save_steps.size()) {
for (int step: save_steps) {
if (step < 0) {
std::stringstream ss;
ss << "Output step must be at least 0, got '" << step << "'." << std::endl;
throw std::runtime_error(ss.str());
}
if (step > time_steps) {
std::stringstream ss;
ss << "Output step can be at most number of time steps run which is '" << time_steps << "', got step '" << step
<< "'." << std::endl;
throw std::runtime_error(ss.str());
}
}
}


void save_state(const State<real_type> &state, int current_year) {
for (size_t i = 0; i < save_steps.size(); ++i) {
if (current_year == save_steps[i]) {
full_state.total_population.chip(i, full_state.total_population.NumDimensions - 1) = state.total_population;
full_state.natural_deaths.chip(i, full_state.natural_deaths.NumDimensions - 1) = state.natural_deaths;
full_state.hiv_population.chip(i, full_state.hiv_population.NumDimensions - 1) = state.hiv_population;
full_state.hiv_natural_deaths.chip(i, full_state.hiv_natural_deaths.NumDimensions - 1) =
state.hiv_natural_deaths;
full_state.hiv_strat_adult.chip(i, full_state.hiv_strat_adult.NumDimensions - 1) = state.hiv_strat_adult;
full_state.art_strat_adult.chip(i, full_state.art_strat_adult.NumDimensions - 1) = state.art_strat_adult;
full_state.births(i) = state.births;
full_state.aids_deaths_no_art.chip(i, full_state.aids_deaths_no_art.NumDimensions - 1) =
state.aids_deaths_no_art;
full_state.infections.chip(i, full_state.infections.NumDimensions - 1) = state.infections;
full_state.aids_deaths_art.chip(i, full_state.aids_deaths_art.NumDimensions - 1) = state.aids_deaths_art;
full_state.art_initiation.chip(i, full_state.art_initiation.NumDimensions - 1) = state.art_initiation;
full_state.hiv_deaths.chip(i, full_state.hiv_deaths.NumDimensions - 1) = state.hiv_deaths;
return;
}
}
}

const OutputState &get_full_state() const {
return full_state;
}


private:
std::vector<int> save_steps;
OutputState full_state;
};

}
3 changes: 3 additions & 0 deletions inst/include/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ using Tensor3 = Eigen::Tensor<real_type, 3>;
template<typename real_type>
using Tensor4 = Eigen::Tensor<real_type, 4>;

template<typename real_type>
using Tensor5 = Eigen::Tensor<real_type, 5>;

template<typename real_type>
struct Parameters {
int num_genders;
Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// run_base_model
Rcpp::List run_base_model(const Rcpp::List data, const Rcpp::List projection_parameters, SEXP sim_years, SEXP hiv_steps_per_year, std::string hiv_age_stratification);
RcppExport SEXP _frogger_run_base_model(SEXP dataSEXP, SEXP projection_parametersSEXP, SEXP sim_yearsSEXP, SEXP hiv_steps_per_yearSEXP, SEXP hiv_age_stratificationSEXP) {
Rcpp::List run_base_model(const Rcpp::List data, const Rcpp::List projection_parameters, SEXP sim_years, SEXP hiv_steps_per_year, Rcpp::NumericVector output_steps, std::string hiv_age_stratification);
RcppExport SEXP _frogger_run_base_model(SEXP dataSEXP, SEXP projection_parametersSEXP, SEXP sim_yearsSEXP, SEXP hiv_steps_per_yearSEXP, SEXP output_stepsSEXP, SEXP hiv_age_stratificationSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const Rcpp::List >::type data(dataSEXP);
Rcpp::traits::input_parameter< const Rcpp::List >::type projection_parameters(projection_parametersSEXP);
Rcpp::traits::input_parameter< SEXP >::type sim_years(sim_yearsSEXP);
Rcpp::traits::input_parameter< SEXP >::type hiv_steps_per_year(hiv_steps_per_yearSEXP);
Rcpp::traits::input_parameter< Rcpp::NumericVector >::type output_steps(output_stepsSEXP);
Rcpp::traits::input_parameter< std::string >::type hiv_age_stratification(hiv_age_stratificationSEXP);
rcpp_result_gen = Rcpp::wrap(run_base_model(data, projection_parameters, sim_years, hiv_steps_per_year, hiv_age_stratification));
rcpp_result_gen = Rcpp::wrap(run_base_model(data, projection_parameters, sim_years, hiv_steps_per_year, output_steps, hiv_age_stratification));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -53,7 +54,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_frogger_run_base_model", (DL_FUNC) &_frogger_run_base_model, 5},
{"_frogger_run_base_model", (DL_FUNC) &_frogger_run_base_model, 6},
{"_frogger_serialize_vector", (DL_FUNC) &_frogger_serialize_vector, 3},
{"_frogger_deserialize_vector", (DL_FUNC) &_frogger_deserialize_vector, 2},
{NULL, NULL, 0}
Expand Down
Loading