diff --git a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp index a129a39e3b7..2733cff1f40 100644 --- a/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp @@ -169,7 +169,8 @@ int hmc_nuts_dense_e_adapt( * @tparam InitWriter A type derived from `stan::callbacks::writer` * @param[in] model Input model to test (with data already instantiated) * @param[in] num_chains The number of chains to run in parallel. `init`, - * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` must + * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` + must * be the same length as this value. * @param[in] init An std vector of init var contexts for initialization of each * chain. @@ -226,56 +227,55 @@ int hmc_nuts_dense_e_adapt( init_buffer, term_buffer, window, interrupt, logger, init_writer[0], sample_writer[0], diagnostic_writer[0]); } - using sample_t = stan::mcmc::adapt_dense_e_nuts; - std::vector rngs; - rngs.reserve(num_chains); - std::vector> cont_vectors; - cont_vectors.reserve(num_chains); - std::vector samplers; - samplers.reserve(num_chains); - try { - for (int i = 0; i < num_chains; ++i) { - rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i)); - cont_vectors.emplace_back(util::initialize(model, *init[i], rngs[i], - init_radius, true, logger, - init_writer[i])); - Eigen::MatrixXd inv_metric = util::read_dense_inv_metric( - *init_inv_metric[i], model.num_params_r(), logger); - util::validate_dense_inv_metric(inv_metric, logger); + using sample_t = stan::mcmc::adapt_dense_e_nuts; + std::vector rngs; + rngs.reserve(num_chains); + std::vector> cont_vectors; + cont_vectors.reserve(num_chains); + std::vector samplers; + samplers.reserve(num_chains); + try { + for (int i = 0; i < num_chains; ++i) { + rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i)); + cont_vectors.emplace_back(util::initialize( + model, *init[i], rngs[i], init_radius, true, logger, init_writer[i])); + Eigen::MatrixXd inv_metric = util::read_dense_inv_metric( + *init_inv_metric[i], model.num_params_r(), logger); + util::validate_dense_inv_metric(inv_metric, logger); - samplers.emplace_back(model, rngs[i]); - samplers[i].set_metric(inv_metric); - samplers[i].set_nominal_stepsize(stepsize); - samplers[i].set_stepsize_jitter(stepsize_jitter); - samplers[i].set_max_depth(max_depth); + samplers.emplace_back(model, rngs[i]); + samplers[i].set_metric(inv_metric); + samplers[i].set_nominal_stepsize(stepsize); + samplers[i].set_stepsize_jitter(stepsize_jitter); + samplers[i].set_max_depth(max_depth); - samplers[i].get_stepsize_adaptation().set_mu(log(10 * stepsize)); - samplers[i].get_stepsize_adaptation().set_delta(delta); - samplers[i].get_stepsize_adaptation().set_gamma(gamma); - samplers[i].get_stepsize_adaptation().set_kappa(kappa); - samplers[i].get_stepsize_adaptation().set_t0(t0); - samplers[i].set_window_params(num_warmup, init_buffer, term_buffer, - window, logger); - } - } catch (const std::domain_error& e) { - return error_codes::CONFIG; + samplers[i].get_stepsize_adaptation().set_mu(log(10 * stepsize)); + samplers[i].get_stepsize_adaptation().set_delta(delta); + samplers[i].get_stepsize_adaptation().set_gamma(gamma); + samplers[i].get_stepsize_adaptation().set_kappa(kappa); + samplers[i].get_stepsize_adaptation().set_t0(t0); + samplers[i].set_window_params(num_warmup, init_buffer, term_buffer, + window, logger); } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, - &diagnostic_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], init_chain_id + i, - num_chains); - } - }, - tbb::simple_partitioner()); - return error_codes::OK; + } catch (const std::domain_error& e) { + return error_codes::CONFIG; + } + tbb::parallel_for(tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, + num_chains, init_chain_id, &samplers, &model, &rngs, + &interrupt, &logger, &sample_writer, &cont_vectors, + &diagnostic_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, + num_samples, num_thin, refresh, save_warmup, + rngs[i], interrupt, logger, sample_writer[i], + diagnostic_writer[i], init_chain_id + i, + num_chains); + } + }, + tbb::simple_partitioner()); + return error_codes::OK; } /** @@ -344,18 +344,18 @@ int hmc_nuts_dense_e_adapt( interrupt, logger, init_writer[0], sample_writer[0], diagnostic_writer[0]); } - std::vector> unit_e_metrics; - unit_e_metrics.reserve(num_chains); - for (size_t i = 0; i < num_chains; ++i) { - unit_e_metrics.emplace_back(std::make_unique( - util::create_unit_e_dense_inv_metric(model.num_params_r()))); - } - return hmc_nuts_dense_e_adapt( - model, num_chains, init, unit_e_metrics, random_seed, init_chain_id, - init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, - stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, - init_buffer, term_buffer, window, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + std::vector> unit_e_metrics; + unit_e_metrics.reserve(num_chains); + for (size_t i = 0; i < num_chains; ++i) { + unit_e_metrics.emplace_back(std::make_unique( + util::create_unit_e_dense_inv_metric(model.num_params_r()))); + } + return hmc_nuts_dense_e_adapt( + model, num_chains, init, unit_e_metrics, random_seed, init_chain_id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, + init_buffer, term_buffer, window, interrupt, logger, init_writer, + sample_writer, diagnostic_writer); } } // namespace sample diff --git a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp index 203ad37ce89..9689fa4a9a1 100644 --- a/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp +++ b/src/stan/services/sample/hmc_nuts_diag_e_adapt.hpp @@ -170,7 +170,8 @@ int hmc_nuts_diag_e_adapt( * @tparam InitWriter A type derived from `stan::callbacks::writer` * @param[in] model Input model to test (with data already instantiated) * @param[in] num_chains The number of chains to run in parallel. `init`, - * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` must + * `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer` + must * be the same length as this value. * @param[in] init An std vector of init var contexts for initialization of each * chain. @@ -227,57 +228,55 @@ int hmc_nuts_diag_e_adapt( init_buffer, term_buffer, window, interrupt, logger, init_writer[0], sample_writer[0], diagnostic_writer[0]); } - using sample_t = stan::mcmc::adapt_diag_e_nuts; - std::vector rngs; - rngs.reserve(num_chains); - std::vector> cont_vectors; - cont_vectors.reserve(num_chains); - std::vector samplers; - samplers.reserve(num_chains); - try { - for (int i = 0; i < num_chains; ++i) { - rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i)); - cont_vectors.emplace_back(util::initialize(model, *init[i], rngs[i], - init_radius, true, logger, - init_writer[i])); - samplers.emplace_back(model, rngs[i]); - Eigen::VectorXd inv_metric = util::read_diag_inv_metric( - *init_inv_metric[i], model.num_params_r(), logger); - util::validate_diag_inv_metric(inv_metric, logger); + using sample_t = stan::mcmc::adapt_diag_e_nuts; + std::vector rngs; + rngs.reserve(num_chains); + std::vector> cont_vectors; + cont_vectors.reserve(num_chains); + std::vector samplers; + samplers.reserve(num_chains); + try { + for (int i = 0; i < num_chains; ++i) { + rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i)); + cont_vectors.emplace_back(util::initialize( + model, *init[i], rngs[i], init_radius, true, logger, init_writer[i])); + samplers.emplace_back(model, rngs[i]); + Eigen::VectorXd inv_metric = util::read_diag_inv_metric( + *init_inv_metric[i], model.num_params_r(), logger); + util::validate_diag_inv_metric(inv_metric, logger); - samplers[i].set_metric(inv_metric); - samplers[i].set_nominal_stepsize(stepsize); - samplers[i].set_stepsize_jitter(stepsize_jitter); - samplers[i].set_max_depth(max_depth); + samplers[i].set_metric(inv_metric); + samplers[i].set_nominal_stepsize(stepsize); + samplers[i].set_stepsize_jitter(stepsize_jitter); + samplers[i].set_max_depth(max_depth); - samplers[i].get_stepsize_adaptation().set_mu(log(10 * stepsize)); - samplers[i].get_stepsize_adaptation().set_delta(delta); - samplers[i].get_stepsize_adaptation().set_gamma(gamma); - samplers[i].get_stepsize_adaptation().set_kappa(kappa); - samplers[i].get_stepsize_adaptation().set_t0(t0); - samplers[i].set_window_params(num_warmup, init_buffer, term_buffer, - window, logger); - } - } catch (const std::domain_error& e) { - return error_codes::CONFIG; + samplers[i].get_stepsize_adaptation().set_mu(log(10 * stepsize)); + samplers[i].get_stepsize_adaptation().set_delta(delta); + samplers[i].get_stepsize_adaptation().set_gamma(gamma); + samplers[i].get_stepsize_adaptation().set_kappa(kappa); + samplers[i].get_stepsize_adaptation().set_t0(t0); + samplers[i].set_window_params(num_warmup, init_buffer, term_buffer, + window, logger); } - tbb::parallel_for( - tbb::blocked_range(0, num_chains, 1), - [num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains, - init_chain_id, &samplers, &model, &rngs, &interrupt, &logger, - &sample_writer, &cont_vectors, - &diagnostic_writer](const tbb::blocked_range& r) { - for (size_t i = r.begin(); i != r.end(); ++i) { - util::run_adaptive_sampler( - samplers[i], model, cont_vectors[i], num_warmup, num_samples, - num_thin, refresh, save_warmup, rngs[i], interrupt, logger, - sample_writer[i], diagnostic_writer[i], init_chain_id + i, - num_chains); - } - }, - tbb::simple_partitioner()); - return error_codes::OK; - + } catch (const std::domain_error& e) { + return error_codes::CONFIG; + } + tbb::parallel_for(tbb::blocked_range(0, num_chains, 1), + [num_warmup, num_samples, num_thin, refresh, save_warmup, + num_chains, init_chain_id, &samplers, &model, &rngs, + &interrupt, &logger, &sample_writer, &cont_vectors, + &diagnostic_writer](const tbb::blocked_range& r) { + for (size_t i = r.begin(); i != r.end(); ++i) { + util::run_adaptive_sampler( + samplers[i], model, cont_vectors[i], num_warmup, + num_samples, num_thin, refresh, save_warmup, + rngs[i], interrupt, logger, sample_writer[i], + diagnostic_writer[i], init_chain_id + i, + num_chains); + } + }, + tbb::simple_partitioner()); + return error_codes::OK; } /** @@ -346,18 +345,18 @@ int hmc_nuts_diag_e_adapt( interrupt, logger, init_writer[0], sample_writer[0], diagnostic_writer[0]); } - std::vector> unit_e_metrics; - unit_e_metrics.reserve(num_chains); - for (size_t i = 0; i < num_chains; ++i) { - unit_e_metrics.emplace_back(std::make_unique( - util::create_unit_e_diag_inv_metric(model.num_params_r()))); - } - return hmc_nuts_diag_e_adapt( - model, num_chains, init, unit_e_metrics, random_seed, init_chain_id, - init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, - stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, - init_buffer, term_buffer, window, interrupt, logger, init_writer, - sample_writer, diagnostic_writer); + std::vector> unit_e_metrics; + unit_e_metrics.reserve(num_chains); + for (size_t i = 0; i < num_chains; ++i) { + unit_e_metrics.emplace_back(std::make_unique( + util::create_unit_e_diag_inv_metric(model.num_params_r()))); + } + return hmc_nuts_diag_e_adapt( + model, num_chains, init, unit_e_metrics, random_seed, init_chain_id, + init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh, + stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0, + init_buffer, term_buffer, window, interrupt, logger, init_writer, + sample_writer, diagnostic_writer); } } // namespace sample diff --git a/src/stan/services/util/create_rng.hpp b/src/stan/services/util/create_rng.hpp index 39f8e4ce9be..d63d6c79f1b 100644 --- a/src/stan/services/util/create_rng.hpp +++ b/src/stan/services/util/create_rng.hpp @@ -7,28 +7,28 @@ namespace stan { namespace services { namespace util { - /** - * Creates a pseudo random number generator from a random seed - * and a chain id by initializing the PRNG with the seed and - * then advancing past pow(2, 50) times the chain ID draws to - * ensure different chains sample from different segments of the - * pseudo random number sequence. - * - * Chain IDs should be kept to larger values than one to ensure - * that the draws used to initialized transformed data are not - * duplicated. - * - * @param[in] seed the random seed - * @param[in] chain the chain id - * @return a boost::ecuyer1988 instance - */ - inline boost::ecuyer1988 create_rng(unsigned int seed, unsigned int chain) { - using boost::uintmax_t; - static constexpr uintmax_t DISCARD_STRIDE = static_cast(1) << 50; - boost::ecuyer1988 rng(seed); - rng.discard(DISCARD_STRIDE * chain); - return rng; - } +/** + * Creates a pseudo random number generator from a random seed + * and a chain id by initializing the PRNG with the seed and + * then advancing past pow(2, 50) times the chain ID draws to + * ensure different chains sample from different segments of the + * pseudo random number sequence. + * + * Chain IDs should be kept to larger values than one to ensure + * that the draws used to initialized transformed data are not + * duplicated. + * + * @param[in] seed the random seed + * @param[in] chain the chain id + * @return a boost::ecuyer1988 instance + */ +inline boost::ecuyer1988 create_rng(unsigned int seed, unsigned int chain) { + using boost::uintmax_t; + static constexpr uintmax_t DISCARD_STRIDE = static_cast(1) << 50; + boost::ecuyer1988 rng(seed); + rng.discard(DISCARD_STRIDE * chain); + return rng; +} } // namespace util } // namespace services diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index ea0f3452225..7115b138fc3 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -71,7 +71,8 @@ void run_adaptive_sampler(Sampler& sampler, Model& model, auto start_warm = std::chrono::steady_clock::now(); util::generate_transitions(sampler, num_warmup, 0, num_warmup + num_samples, num_thin, refresh, save_warmup, true, writer, s, - model, rng, interrupt, logger, chain_id, num_chains); + model, rng, interrupt, logger, chain_id, + num_chains); auto end_warm = std::chrono::steady_clock::now(); double warm_delta_t = std::chrono::duration_cast( end_warm - start_warm)