diff --git a/src/stan/mcmc/hmc/base_hmc.hpp b/src/stan/mcmc/hmc/base_hmc.hpp index d170ff4f6b..80c26f7595 100644 --- a/src/stan/mcmc/hmc/base_hmc.hpp +++ b/src/stan/mcmc/hmc/base_hmc.hpp @@ -70,6 +70,7 @@ class base_hmc : public base_mcmc { void write_sampler_state_struct(callbacks::structured_writer& struct_writer) { struct_writer.begin_record(); struct_writer.write("stepsize", get_nominal_stepsize()); + struct_writer.write("metric_type", z_.metric_type()); struct_writer.write("inv_metric", z_.inv_e_metric_); struct_writer.end_record(); } diff --git a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp index 3b811e280e..ac5b4c3eaf 100644 --- a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp @@ -51,6 +51,8 @@ class dense_e_point : public ps_point { writer(inv_e_metric_ss.str()); } } + + inline std::string metric_type() { return "dense_e"; } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp index eb2c2d9f5d..f5a15de2cb 100644 --- a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp @@ -49,6 +49,8 @@ class diag_e_point : public ps_point { inv_e_metric_ss << ", " << inv_e_metric_(i); writer(inv_e_metric_ss.str()); } + + inline std::string metric_type() { return "diag_e"; } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp b/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp index bddd225f9f..b782e2a09c 100644 --- a/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp @@ -43,6 +43,8 @@ class softabs_point : public ps_point { virtual inline void write_metric(stan::callbacks::writer& writer) { writer("No free parameters for SoftAbs metric"); } + + inline std::string metric_type() { return "softabs"; } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp index cb6b439409..c4e9933464 100644 --- a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp @@ -30,6 +30,8 @@ class unit_e_point : public ps_point { inline void write_metric(stan::callbacks::writer& writer) { writer("No free parameters for unit metric"); } + + inline std::string metric_type() { return "unit_e"; } }; } // namespace mcmc diff --git a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp index 1250f05dc9..f98c59a820 100644 --- a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp @@ -125,6 +125,7 @@ TEST_F(ServicesSampleHmcNutsDenseEAdaptParMatch, single_multi_match) { par_metrics.push_back(ss_metric[i].str()); ASSERT_TRUE(stan::test::is_valid_JSON(par_metrics[i])); EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1); + EXPECT_EQ(count_matches("metric_type", par_metrics[i]), 1); EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1); EXPECT_EQ(count_matches("[", par_metrics[i]), 3); // list has 2 rows } diff --git a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp index 50da61b09f..ca3b326c54 100644 --- a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp @@ -124,6 +124,7 @@ TEST_F(ServicesSampleHmcNutsDiagEAdaptParMatch, single_multi_match) { par_metrics.push_back(ss_metric[i].str()); ASSERT_TRUE(stan::test::is_valid_JSON(par_metrics[i])); EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1); + EXPECT_EQ(count_matches("metric_type", par_metrics[i]), 1); EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1); EXPECT_EQ(count_matches("[", par_metrics[i]), 1); // single list } diff --git a/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp b/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp index 5574fa685a..2889c6b229 100644 --- a/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp @@ -128,6 +128,7 @@ TEST_F(ServicesSampleHmcNutsUnitEAdaptPar, parameter_checks) { // Adapted metric ASSERT_TRUE(stan::test::is_valid_JSON(metric)); EXPECT_EQ(count_matches("stepsize", metric), 1); + EXPECT_EQ(count_matches("metric_type", metric), 1); EXPECT_EQ(count_matches("inv_metric", metric), 1); EXPECT_EQ(count_matches("[", metric), 1); // single list EXPECT_EQ(count_matches("[ 1, 1 ]", metric), 1); // unit diagonal