Skip to content

Commit

Permalink
Make sure to use num_chains instead of n_chain everywhere, fix docs, …
Browse files Browse the repository at this point in the history
…removes unneeded else branch
  • Loading branch information
SteveBronder committed Jul 19, 2021
1 parent bbf4d37 commit 4f82e24
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 65 deletions.
32 changes: 15 additions & 17 deletions src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,17 @@ int hmc_nuts_dense_e_adapt(
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model to test (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` must
* be the same length as this value.
* @param[in] init An std vector of init var contexts for initialization of each
chain.
* chain.
* @param[in] init_inv_metric An std vector of var contexts exposing an initial
diagonal inverse Euclidean metric for each chain (must be positive definite)
* diagonal inverse Euclidean metric for each chain (must be positive definite)
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
will advance by for each chain by an integer sequence from `init_chain_id` to
`num_chains`
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
Expand All @@ -198,10 +201,7 @@ int hmc_nuts_dense_e_adapt(
inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
information of each chain.
* @param[in] num_chains The number of chains to run in parallel. `init`,
`init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` must
be the same length as this value.
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
Expand All @@ -225,7 +225,7 @@ int hmc_nuts_dense_e_adapt(
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
} else {
}
using sample_t = stan::mcmc::adapt_dense_e_nuts<Model, boost::ecuyer1988>;
std::vector<boost::ecuyer1988> rngs;
rngs.reserve(num_chains);
Expand All @@ -235,7 +235,7 @@ int hmc_nuts_dense_e_adapt(
samplers.reserve(num_chains);
try {
for (int i = 0; i < num_chains; ++i) {
rngs.emplace_back(util::create_rng(random_seed, init_chain_id, i));
rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i));
cont_vectors.emplace_back(util::initialize(model, *init[i], rngs[i],
init_radius, true, logger,
init_writer[i]));
Expand Down Expand Up @@ -276,7 +276,6 @@ int hmc_nuts_dense_e_adapt(
},
tbb::simple_partitioner());
return error_codes::OK;
}
}

/**
Expand All @@ -290,12 +289,15 @@ int hmc_nuts_dense_e_adapt(
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model to test (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @param[in] init An std vector of init var contexts for initialization of each
* chain.
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `num_chains`
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
Expand All @@ -319,9 +321,6 @@ int hmc_nuts_dense_e_adapt(
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
Expand All @@ -344,7 +343,7 @@ int hmc_nuts_dense_e_adapt(
max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
interrupt, logger, init_writer[0], sample_writer[0],
diagnostic_writer[0]);
} else {
}
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metrics;
unit_e_metrics.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
Expand All @@ -357,7 +356,6 @@ int hmc_nuts_dense_e_adapt(
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer,
sample_writer, diagnostic_writer);
}
}

} // namespace sample
Expand Down
33 changes: 16 additions & 17 deletions src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,17 @@ int hmc_nuts_diag_e_adapt(
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model to test (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` must
* be the same length as this value.
* @param[in] init An std vector of init var contexts for initialization of each
chain.
* chain.
* @param[in] init_inv_metric An std vector of var contexts exposing an initial
diagonal inverse Euclidean metric for each chain (must be positive definite)
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
will advance by for each chain by an integer sequence from `init_chain_id` to
`num_chains`
* will advance for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id + num_chains - 1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
Expand All @@ -196,13 +199,10 @@ int hmc_nuts_diag_e_adapt(
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
inits of each chain.
* inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
information of each chain.
* @param[in] num_chains The number of chains to run in parallel. `init`,
`init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` must
be the same length as this value.
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
Expand All @@ -226,7 +226,7 @@ int hmc_nuts_diag_e_adapt(
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
} else {
}
using sample_t = stan::mcmc::adapt_diag_e_nuts<Model, boost::ecuyer1988>;
std::vector<boost::ecuyer1988> rngs;
rngs.reserve(num_chains);
Expand All @@ -236,7 +236,7 @@ int hmc_nuts_diag_e_adapt(
samplers.reserve(num_chains);
try {
for (int i = 0; i < num_chains; ++i) {
rngs.emplace_back(util::create_rng(random_seed, init_chain_id, i));
rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i));
cont_vectors.emplace_back(util::initialize(model, *init[i], rngs[i],
init_radius, true, logger,
init_writer[i]));
Expand Down Expand Up @@ -277,7 +277,7 @@ int hmc_nuts_diag_e_adapt(
},
tbb::simple_partitioner());
return error_codes::OK;
}

}

/**
Expand All @@ -291,12 +291,15 @@ int hmc_nuts_diag_e_adapt(
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model to test (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @param[in] init An std vector of init var contexts for initialization of each
* chain.
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `num_chains`
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
Expand All @@ -320,9 +323,6 @@ int hmc_nuts_diag_e_adapt(
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
Expand All @@ -345,7 +345,7 @@ int hmc_nuts_diag_e_adapt(
max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
interrupt, logger, init_writer[0], sample_writer[0],
diagnostic_writer[0]);
} else {
}
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metrics;
unit_e_metrics.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
Expand All @@ -358,7 +358,6 @@ int hmc_nuts_diag_e_adapt(
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer,
sample_writer, diagnostic_writer);
}
}

} // namespace sample
Expand Down
47 changes: 22 additions & 25 deletions src/stan/services/util/create_rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,28 @@ namespace stan {
namespace services {
namespace util {

/**
* Creates a pseudo random number generator from a random seed
* and a chain id by initializing the PRNG with the seed and
* then advancing past pow(2, 50) times the chain ID draws to
* ensure different chains sample from different segments of the
* pseudo random number sequence.
*
* Chain IDs should be kept to larger values than one to ensure
* that the draws used to initialized transformed data are not
* duplicated.
*
* @param[in] seed the random seed
* @param[in] init_chain_id the chain id
* @param[in] chain_num For multi-chain, the ch
* @return a boost::ecuyer1988 instance
*/
inline boost::ecuyer1988 create_rng(unsigned int seed,
unsigned int init_chain_id,
unsigned int chain_num = 0) {
using boost::uintmax_t;
constexpr static uintmax_t DISCARD_STRIDE = static_cast<uintmax_t>(1) << 50;
boost::ecuyer1988 rng(seed);
rng.discard(DISCARD_STRIDE * (init_chain_id + chain_num));
return rng;
}
/**
* Creates a pseudo random number generator from a random seed
* and a chain id by initializing the PRNG with the seed and
* then advancing past pow(2, 50) times the chain ID draws to
* ensure different chains sample from different segments of the
* pseudo random number sequence.
*
* Chain IDs should be kept to larger values than one to ensure
* that the draws used to initialized transformed data are not
* duplicated.
*
* @param[in] seed the random seed
* @param[in] chain the chain id
* @return a boost::ecuyer1988 instance
*/
inline boost::ecuyer1988 create_rng(unsigned int seed, unsigned int chain) {
using boost::uintmax_t;
static constexpr uintmax_t DISCARD_STRIDE = static_cast<uintmax_t>(1) << 50;
boost::ecuyer1988 rng(seed);
rng.discard(DISCARD_STRIDE * chain);
return rng;
}

} // namespace util
} // namespace services
Expand Down
6 changes: 4 additions & 2 deletions src/stan/services/util/generate_transitions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ namespace util {
* @param[in,out] base_rng random number generator
* @param[in,out] callback interrupt callback called once an iteration
* @param[in,out] logger logger for messages
* @param[in] num_chains The number of chains used in the program. This
* is used in generate transitions to print out the chain number.
*/
template <class Model, class RNG>
void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations,
Expand All @@ -45,15 +47,15 @@ void generate_transitions(stan::mcmc::base_mcmc& sampler, int num_iterations,
stan::mcmc::sample& init_s, Model& model,
RNG& base_rng, callbacks::interrupt& callback,
callbacks::logger& logger, size_t chain_id = 1,
size_t n_chain = 1) {
size_t num_chains = 1) {
for (int m = 0; m < num_iterations; ++m) {
callback();

if (refresh > 0
&& (start + m + 1 == finish || m == 0 || (m + 1) % refresh == 0)) {
int it_print_width = std::ceil(std::log10(static_cast<double>(finish)));
std::stringstream message;
if (n_chain != 1) {
if (num_chains != 1) {
message << "Chain [" << chain_id << "] ";
}
message << "Iteration: ";
Expand Down
8 changes: 4 additions & 4 deletions src/stan/services/util/run_adaptive_sampler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace util {
* @param[in,out] sample_writer writer for draws
* @param[in,out] diagnostic_writer writer for diagnostic information
* @param[in] chain_id The id for a given chain.
* @param[in] n_chain The number of chains used in the program. This
* @param[in] num_chains The number of chains used in the program. This
* is used in generate transitions to print out the chain number.
*/
template <typename Sampler, typename Model, typename RNG>
Expand All @@ -47,7 +47,7 @@ void run_adaptive_sampler(Sampler& sampler, Model& model,
callbacks::logger& logger,
callbacks::writer& sample_writer,
callbacks::writer& diagnostic_writer,
size_t chain_id = 1, size_t n_chain = 1) {
size_t chain_id = 1, size_t num_chains = 1) {
Eigen::Map<Eigen::VectorXd> cont_params(cont_vector.data(),
cont_vector.size());

Expand All @@ -71,7 +71,7 @@ void run_adaptive_sampler(Sampler& sampler, Model& model,
auto start_warm = std::chrono::steady_clock::now();
util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples,
num_thin, refresh, save_warmup, true, writer, s,
model, rng, interrupt, logger, chain_id, n_chain);
model, rng, interrupt, logger, chain_id, num_chains);
auto end_warm = std::chrono::steady_clock::now();
double warm_delta_t = std::chrono::duration_cast<std::chrono::milliseconds>(
end_warm - start_warm)
Expand All @@ -85,7 +85,7 @@ void run_adaptive_sampler(Sampler& sampler, Model& model,
util::generate_transitions(sampler, num_samples, num_warmup,
num_warmup + num_samples, num_thin, refresh, true,
false, writer, s, model, rng, interrupt, logger,
chain_id, n_chain);
chain_id, num_chains);
auto end_sample = std::chrono::steady_clock::now();
double sample_delta_t = std::chrono::duration_cast<std::chrono::milliseconds>(
end_sample - start_sample)
Expand Down

0 comments on commit 4f82e24

Please sign in to comment.