diff --git a/src/stan/services/util/run_adaptive_sampler.hpp b/src/stan/services/util/run_adaptive_sampler.hpp index b2683e3c0b1..96221c8abf4 100644 --- a/src/stan/services/util/run_adaptive_sampler.hpp +++ b/src/stan/services/util/run_adaptive_sampler.hpp @@ -34,6 +34,9 @@ namespace util { * @param[in,out] logger logger for messages * @param[in,out] sample_writer writer for draws * @param[in,out] diagnostic_writer writer for diagnostic information + * @param[in] chain_id The id for a given chain. + * @param[in] n_chain The number of chains used in the program. This + * is used in generate transitions to print out the chain number. */ template void run_adaptive_sampler(Sampler& sampler, Model& model, diff --git a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp new file mode 100644 index 00000000000..3426afba375 --- /dev/null +++ b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +auto&& blah = stan::math::init_threadpool_tbb(); + + +static constexpr size_t num_chains = 4; +class ServicesSampleHmcNutsDenseEAdaptParMatch : public testing::Test { + public: + ServicesSampleHmcNutsDenseEAdaptParMatch() + : model(std::make_unique( + data_context, 0, &model_log)) { + for (int i = 0; i < num_chains; ++i) { + init.push_back(stan::test::unit::instrumented_writer{}); + par_parameters.emplace_back(std::make_unique(), "#"); + seq_parameters.emplace_back(std::make_unique(), "#"); + diagnostic.push_back(stan::test::unit::instrumented_writer{}); + context.push_back(std::make_shared()); + } + } + stan::io::empty_var_context data_context; + std::stringstream model_log; + stan::test::unit::instrumented_logger logger; + std::vector init; + using str_writer = stan::callbacks::unique_stream_writer; + std::vector par_parameters; + std::vector seq_parameters; + std::vector diagnostic; + std::vector> context; + std::unique_ptr model; +}; + +/** + * This test checks that running multiple chains in one call + * with the same initial id is the same as running multiple calls + * with incrementing chain ids. + */ +TEST_F(ServicesSampleHmcNutsDenseEAdaptParMatch, single_multi_match) { + constexpr unsigned int random_seed = 0; + constexpr unsigned int chain = 0; + constexpr double init_radius = 0; + constexpr int num_warmup = 200; + constexpr int num_samples = 400; + constexpr int num_thin = 5; + constexpr bool save_warmup = true; + constexpr int refresh = 0; + constexpr double stepsize = 0.1; + constexpr double stepsize_jitter = 0; + constexpr int max_depth = 8; + constexpr double delta = .1; + constexpr double gamma = .1; + constexpr double kappa = .1; + constexpr double t0 = .1; + constexpr unsigned int init_buffer = 50; + constexpr unsigned int term_buffer = 50; + constexpr unsigned int window = 100; + stan::test::unit::instrumented_interrupt interrupt; + EXPECT_EQ(interrupt.call_count(), 0); + int return_code = stan::services::sample::hmc_nuts_dense_e_adapt( + *model, num_chains, context, random_seed, chain, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, init, par_parameters, diagnostic); + + EXPECT_EQ(0, return_code); + + int num_output_lines = (num_warmup + num_samples) / num_thin; + EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count()); + for (int i = 0; i < num_chains; ++i) { + stan::test::unit::instrumented_writer seq_init; + stan::test::unit::instrumented_writer seq_diagnostic; + return_code = stan::services::sample::hmc_nuts_dense_e_adapt( + *model, *(context[i]), random_seed, i, init_radius, num_warmup, + num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter, + max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window, + interrupt, logger, seq_init, seq_parameters[i], seq_diagnostic); + EXPECT_EQ(0, return_code); + } + std::vector par_res; + for (int i = 0; i < num_chains; ++i) { + auto par_str = par_parameters[i].get_stream().str(); + auto sub_par_str = par_str.substr(par_str.find("Elements") - 1); + std::istringstream sub_par_stream(sub_par_str); + Eigen::MatrixXd par_mat + = stan::test::read_stan_sample_csv(sub_par_stream, 80, 9); + par_res.push_back(par_mat); + } + std::vector seq_res; + for (int i = 0; i < num_chains; ++i) { + auto seq_str = seq_parameters[i].get_stream().str(); + auto sub_seq_str = seq_str.substr(seq_str.find("Elements") - 1); + std::istringstream sub_seq_stream(sub_seq_str); + Eigen::MatrixXd seq_mat + = stan::test::read_stan_sample_csv(sub_seq_stream, 80, 9); + seq_res.push_back(seq_mat); + } + for (int i = 0; i < num_chains; ++i) { + Eigen::MatrixXd diff_res + = (par_res[i].array() - seq_res[i].array()).matrix(); + EXPECT_MATRIX_EQ(diff_res, Eigen::MatrixXd::Zero(80, 9)); + } +} diff --git a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp index 8f9274b7dac..c4167419cbf 100644 --- a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -9,43 +10,6 @@ auto&& blah = stan::math::init_threadpool_tbb(); -namespace stan { -namespace test { -/** - * Read a CSV into an Eigen matrix. - * @param in An input string stream holding the CSV - * @param rows Number of rows - * @param cols Number of columns. - */ -Eigen::MatrixXd read_stan_sample_csv(std::istringstream& in, int rows, - int cols) { - std::string line; - int row = 0; - int col = 0; - Eigen::MatrixXd res = Eigen::MatrixXd(rows, cols); - while (std::getline(in, line)) { - if (line.find("#") != std::string::npos) { - continue; - } - const char* ptr = line.c_str(); - int len = line.length(); - col = 0; - - const char* start = ptr; - for (int i = 0; i < len; i++) { - if (ptr[i] == ',') { - res(row, col++) = atof(start); - start = ptr + i + 1; - } - } - res(row, col) = atof(start); - row++; - } - return res; -} -} // namespace test -} // namespace stan - static constexpr size_t num_chains = 4; class ServicesSampleHmcNutsDiagEAdaptParMatch : public testing::Test { public: @@ -72,6 +36,11 @@ class ServicesSampleHmcNutsDiagEAdaptParMatch : public testing::Test { std::unique_ptr model; }; +/** + * This test checks that running multiple chains in one call + * with the same initial id is the same as running multiple calls + * with incrementing chain ids. + */ TEST_F(ServicesSampleHmcNutsDiagEAdaptParMatch, single_multi_match) { constexpr unsigned int random_seed = 0; constexpr unsigned int chain = 0; diff --git a/src/test/unit/services/util.hpp b/src/test/unit/services/util.hpp new file mode 100644 index 00000000000..f283fa6f49f --- /dev/null +++ b/src/test/unit/services/util.hpp @@ -0,0 +1,45 @@ +#ifndef STAN_SRC_TEST_UNIT_SERVICES_UTIL_HPP +#define STAN_SRC_TEST_UNIT_SERVICES_UTIL_HPP + +#include +#include +#include +#include + +namespace stan { +namespace test { +/** + * Read a CSV into an Eigen matrix. + * @param in An input string stream holding the CSV + * @param rows Number of rows + * @param cols Number of columns. + */ +Eigen::MatrixXd read_stan_sample_csv(std::istringstream& in, int rows, + int cols) { + std::string line; + int row = 0; + int col = 0; + Eigen::MatrixXd res = Eigen::MatrixXd(rows, cols); + while (std::getline(in, line)) { + if (line.find("#") != std::string::npos) { + continue; + } + const char* ptr = line.c_str(); + int len = line.length(); + col = 0; + + const char* start = ptr; + for (int i = 0; i < len; i++) { + if (ptr[i] == ',') { + res(row, col++) = atof(start); + start = ptr + i + 1; + } + } + res(row, col) = atof(start); + row++; + } + return res; +} +} // namespace test +} // namespace stan +#endif