Skip to content

Commit

Permalink
Merge pull request #3287 from stan-dev/fix/rng-init
Browse files Browse the repository at this point in the history
Ensure mixmax rng is not initialized with all zeros
  • Loading branch information
WardBrian authored May 22, 2024
2 parents 62b2a19 + ef3b05b commit 356d206
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 59 deletions.
5 changes: 4 additions & 1 deletion src/stan/services/util/create_rng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ namespace util {
* @return an stan::rng_t instance
*/
inline rng_t create_rng(unsigned int seed, unsigned int chain) {
rng_t rng(seed + chain);
// RNG state is 128 bits, but user only provides 64 total bits
// Additionally, there are issues if all 128 bits are 0, hence
// the 1 as the second argument
rng_t rng(0, 1, seed, chain);
return rng;
}

Expand Down
4 changes: 2 additions & 2 deletions src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ TEST(McmcNutsBaseNuts, transition) {
EXPECT_EQ((2 << (sampler.get_max_depth() - 1)) - 1, sampler.n_leapfrog_);
EXPECT_FALSE(sampler.divergent_);

EXPECT_EQ(-31 * init_momentum, s.cont_params()(0));
EXPECT_EQ(23 * init_momentum, s.cont_params()(0));
EXPECT_EQ(0, s.log_prob());
EXPECT_EQ(1, s.accept_stat());
EXPECT_EQ("", debug.str());
Expand All @@ -373,7 +373,7 @@ TEST(McmcNutsBaseNuts, transition) {
}

TEST(McmcNutsBaseNuts, transition_egde_momenta) {
stan::rng_t base_rng = stan::services::util::create_rng(424243, 0);
stan::rng_t base_rng = stan::services::util::create_rng(42424253, 0);

int model_size = 1;
double init_momentum = 1.5;
Expand Down
14 changes: 7 additions & 7 deletions src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,15 @@ TEST(McmcSoftAbsNuts, transition_test) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_EQ(5, sampler.depth_);
EXPECT_EQ((2 << 4) - 1, sampler.n_leapfrog_);
EXPECT_EQ(3, sampler.depth_);
EXPECT_EQ((2 << 3) - 1, sampler.n_leapfrog_);
EXPECT_FALSE(sampler.divergent_);

EXPECT_FLOAT_EQ(-1.7373296, s.cont_params()(0));
EXPECT_FLOAT_EQ(1.0898665, s.cont_params()(1));
EXPECT_FLOAT_EQ(-0.38303182, s.cont_params()(2));
EXPECT_FLOAT_EQ(-2.1764181, s.log_prob());
EXPECT_FLOAT_EQ(0.9993856, s.accept_stat());
EXPECT_FLOAT_EQ(0.74693149, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.74414188, s.cont_params()(1));
EXPECT_FLOAT_EQ(0.60859376, s.cont_params()(2));
EXPECT_FLOAT_EQ(-0.74102008, s.log_prob());
EXPECT_FLOAT_EQ(0.99934167, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down
14 changes: 7 additions & 7 deletions src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,15 @@ TEST(McmcUnitENuts, transition_test) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_EQ(5, sampler.depth_);
EXPECT_EQ((2 << 4) - 1, sampler.n_leapfrog_);
EXPECT_EQ(3, sampler.depth_);
EXPECT_EQ((2 << 3) - 1, sampler.n_leapfrog_);
EXPECT_FALSE(sampler.divergent_);

EXPECT_FLOAT_EQ(-1.7890506, s.cont_params()(0));
EXPECT_FLOAT_EQ(1.2320533, s.cont_params()(1));
EXPECT_FLOAT_EQ(-0.62397981, s.cont_params()(2));
EXPECT_FLOAT_EQ(-2.554004, s.log_prob());
EXPECT_FLOAT_EQ(0.99910343, s.accept_stat());
EXPECT_FLOAT_EQ(0.70149082, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.69831347, s.cont_params()(1));
EXPECT_FLOAT_EQ(0.54392564, s.cont_params()(2));
EXPECT_FLOAT_EQ(-0.63779306, s.log_prob());
EXPECT_FLOAT_EQ(0.99912512, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ TEST(McmcStaticUniform, unit_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.59627211, s.log_prob());
EXPECT_FLOAT_EQ(0.99985325, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down Expand Up @@ -78,9 +78,9 @@ TEST(McmcStaticUniform, diag_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.59627211, s.log_prob());
EXPECT_FLOAT_EQ(0.99985325, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down Expand Up @@ -115,9 +115,9 @@ TEST(McmcStaticUniform, dense_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.59627211, s.log_prob());
EXPECT_FLOAT_EQ(0.99985325, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down Expand Up @@ -152,9 +152,9 @@ TEST(McmcStaticUniform, softabs_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5338461, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.176342, s.log_prob());
EXPECT_FLOAT_EQ(0.9996115, s.accept_stat());
EXPECT_FLOAT_EQ(1.0826443, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.58605933, s.log_prob());
EXPECT_FLOAT_EQ(0.99989599, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down Expand Up @@ -189,9 +189,9 @@ TEST(McmcStaticUniform, adapt_unit_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.59627211, s.log_prob());
EXPECT_FLOAT_EQ(0.99985325, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down Expand Up @@ -226,9 +226,9 @@ TEST(McmcStaticUniform, adapt_diag_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.59627211, s.log_prob());
EXPECT_FLOAT_EQ(0.99985325, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down Expand Up @@ -263,9 +263,9 @@ TEST(McmcStaticUniform, adapt_dense_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.2635686, s.log_prob());
EXPECT_FLOAT_EQ(0.9994188, s.accept_stat());
EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.59627211, s.log_prob());
EXPECT_FLOAT_EQ(0.99985325, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down Expand Up @@ -300,9 +300,9 @@ TEST(McmcStaticUniform, adapt_softabs_e_transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_FLOAT_EQ(1.5338461, s.cont_params()(0));
EXPECT_FLOAT_EQ(-1.176342, s.log_prob());
EXPECT_FLOAT_EQ(0.9996115, s.accept_stat());
EXPECT_FLOAT_EQ(1.0826443, s.cont_params()(0));
EXPECT_FLOAT_EQ(-0.58605933, s.log_prob());
EXPECT_FLOAT_EQ(0.99989599, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down
4 changes: 2 additions & 2 deletions src/test/unit/mcmc/hmc/xhmc/base_xhmc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ TEST(McmcXHMCBaseXHMC, divergence_test) {
}

TEST(McmcXHMCBaseXHMC, transition) {
stan::rng_t base_rng = stan::services::util::create_rng(0, 0);
stan::rng_t base_rng = stan::services::util::create_rng(1234, 0);

int model_size = 1;
double init_momentum = 1.5;
Expand All @@ -245,7 +245,7 @@ TEST(McmcXHMCBaseXHMC, transition) {

stan::mcmc::sample s = sampler.transition(init_sample, logger);

EXPECT_EQ(-31 * init_momentum, s.cont_params()(0));
EXPECT_EQ(5 * init_momentum, s.cont_params()(0));
EXPECT_EQ(0, s.log_prob());
EXPECT_EQ(1, s.accept_stat());
EXPECT_EQ("", debug.str());
Expand Down
16 changes: 8 additions & 8 deletions src/test/unit/mcmc/hmc/xhmc/softabs_xhmc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ TEST(McmcUnitEXHMC, build_tree) {
EXPECT_FLOAT_EQ(1.5019561, sampler.z().p(1));
EXPECT_FLOAT_EQ(-1.5019561, sampler.z().p(2));

EXPECT_FLOAT_EQ(0.42903179, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.42903179, z_propose.q(1));
EXPECT_FLOAT_EQ(0.42903179, z_propose.q(2));
EXPECT_FLOAT_EQ(0.8330583, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.8330583, z_propose.q(1));
EXPECT_FLOAT_EQ(0.8330583, z_propose.q(2));

EXPECT_FLOAT_EQ(-1.4385087, z_propose.p(0));
EXPECT_FLOAT_EQ(1.4385087, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.4385087, z_propose.p(2));
EXPECT_FLOAT_EQ(-1.1836562, z_propose.p(0));
EXPECT_FLOAT_EQ(1.1836562, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.1836562, z_propose.p(2));

EXPECT_EQ(8, n_leapfrog);
EXPECT_FLOAT_EQ(3.7645235, ave);
Expand All @@ -79,7 +79,7 @@ TEST(McmcUnitEXHMC, build_tree) {
}

TEST(McmcUnitEXHMC, transition) {
stan::rng_t base_rng = stan::services::util::create_rng(483294, 0);
stan::rng_t base_rng = stan::services::util::create_rng(4832942, 0);

stan::mcmc::softabs_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -112,7 +112,7 @@ TEST(McmcUnitEXHMC, transition) {
EXPECT_FLOAT_EQ(-1, s.cont_params()(1));
EXPECT_FLOAT_EQ(1, s.cont_params()(2));
EXPECT_FLOAT_EQ(-1.5, s.log_prob());
EXPECT_FLOAT_EQ(0.99993229, s.accept_stat());
EXPECT_FLOAT_EQ(0.99870497, s.accept_stat());

EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
Expand Down
16 changes: 8 additions & 8 deletions src/test/unit/mcmc/hmc/xhmc/unit_e_xhmc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ TEST(McmcUnitEXHMC, build_tree) {
EXPECT_FLOAT_EQ(1.4131583, sampler.z().p(1));
EXPECT_FLOAT_EQ(-1.4131583, sampler.z().p(2));

EXPECT_FLOAT_EQ(0.65928948, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.65928948, z_propose.q(1));
EXPECT_FLOAT_EQ(0.65928948, z_propose.q(2));
EXPECT_FLOAT_EQ(0.11940599, z_propose.q(0));
EXPECT_FLOAT_EQ(-0.11940599, z_propose.q(1));
EXPECT_FLOAT_EQ(0.11940599, z_propose.q(2));

EXPECT_FLOAT_EQ(-1.2505695, z_propose.p(0));
EXPECT_FLOAT_EQ(1.2505695, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.2505695, z_propose.p(2));
EXPECT_FLOAT_EQ(-1.408289, z_propose.p(0));
EXPECT_FLOAT_EQ(1.408289, z_propose.p(1));
EXPECT_FLOAT_EQ(-1.408289, z_propose.p(2));

EXPECT_EQ(8, n_leapfrog);
EXPECT_FLOAT_EQ(4.2207355, ave);
Expand All @@ -79,7 +79,7 @@ TEST(McmcUnitEXHMC, build_tree) {
}

TEST(McmcUnitEXHMC, transition) {
stan::rng_t base_rng = stan::services::util::create_rng(483294, 0);
stan::rng_t base_rng = stan::services::util::create_rng(4832942, 0);

stan::mcmc::unit_e_point z_init(3);
z_init.q(0) = 1;
Expand Down Expand Up @@ -112,7 +112,7 @@ TEST(McmcUnitEXHMC, transition) {
EXPECT_FLOAT_EQ(-1, s.cont_params()(1));
EXPECT_FLOAT_EQ(1, s.cont_params()(2));
EXPECT_FLOAT_EQ(-1.5, s.log_prob());
EXPECT_FLOAT_EQ(0.99994934, s.accept_stat());
EXPECT_FLOAT_EQ(0.99870926, s.accept_stat());
EXPECT_EQ("", debug.str());
EXPECT_EQ("", info.str());
EXPECT_EQ("", warn.str());
Expand Down

0 comments on commit 356d206

Please sign in to comment.