From 3591ded7c84ee786fc20a54763ef841b55282996 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 7 Dec 2023 17:12:30 -0500 Subject: [PATCH 1/4] report metric type based on point type --- src/stan/mcmc/hmc/base_hmc.hpp | 1 + src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp | 4 ++++ src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp | 4 ++++ src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp | 4 ++++ src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp | 4 ++++ 5 files changed, 17 insertions(+) diff --git a/src/stan/mcmc/hmc/base_hmc.hpp b/src/stan/mcmc/hmc/base_hmc.hpp index d170ff4f6b3..9895b10f6ca 100644 --- a/src/stan/mcmc/hmc/base_hmc.hpp +++ b/src/stan/mcmc/hmc/base_hmc.hpp @@ -71,6 +71,7 @@ class base_hmc : public base_mcmc { struct_writer.begin_record(); struct_writer.write("stepsize", get_nominal_stepsize()); struct_writer.write("inv_metric", z_.inv_e_metric_); + struct_writer.write("metric_type", z_.metric_type()); struct_writer.end_record(); } diff --git a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp index 3b811e280e3..35590fc5974 100644 --- a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp @@ -51,6 +51,10 @@ class dense_e_point : public ps_point { writer(inv_e_metric_ss.str()); } } + + inline std::string metric_type() { + return "dense"; + } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp index eb2c2d9f5d5..5fd58b42385 100644 --- a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp @@ -49,6 +49,10 @@ class diag_e_point : public ps_point { inv_e_metric_ss << ", " << inv_e_metric_(i); writer(inv_e_metric_ss.str()); } + + inline std::string metric_type() { + return "diag"; + } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp b/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp index bddd225f9f8..d5ba1e01aee 100644 --- a/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp @@ -43,6 +43,10 @@ class softabs_point : public ps_point { virtual inline void write_metric(stan::callbacks::writer& writer) { writer("No free parameters for SoftAbs metric"); } + + inline std::string metric_type() { + return "softabs"; + } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp index cb6b439409c..923f588eaca 100644 --- a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp @@ -30,6 +30,10 @@ class unit_e_point : public ps_point { inline void write_metric(stan::callbacks::writer& writer) { writer("No free parameters for unit metric"); } + + inline std::string metric_type() { + return "unit"; + } }; } // namespace mcmc From cebb8bb85eb0d16b3518bbc602a327f7cc0a23ba Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 7 Dec 2023 17:30:04 -0500 Subject: [PATCH 2/4] metric types match cmdstan args --- src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp | 2 +- src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp | 2 +- src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp index 35590fc5974..b5497f27a09 100644 --- a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp @@ -53,7 +53,7 @@ class dense_e_point : public ps_point { } inline std::string metric_type() { - return "dense"; + return "dense_e"; } }; diff --git a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp index 5fd58b42385..70bfe710aeb 100644 --- a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp @@ -51,7 +51,7 @@ class diag_e_point : public ps_point { } inline std::string metric_type() { - return "diag"; + return "diag_e"; } }; diff --git a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp index 923f588eaca..5282b3b3ba2 100644 --- a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp @@ -32,7 +32,7 @@ class unit_e_point : public ps_point { } inline std::string metric_type() { - return "unit"; + return "unit_e"; } }; From a4c1c5e1591b47678eed2f35d9701d10929a6931 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Fri, 8 Dec 2023 16:05:41 -0500 Subject: [PATCH 3/4] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp | 4 +--- src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp | 4 +--- src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp | 4 +--- src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp | 4 +--- 4 files changed, 4 insertions(+), 12 deletions(-) diff --git a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp index b5497f27a09..ac5b4c3eafd 100644 --- a/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/dense_e_point.hpp @@ -52,9 +52,7 @@ class dense_e_point : public ps_point { } } - inline std::string metric_type() { - return "dense_e"; - } + inline std::string metric_type() { return "dense_e"; } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp index 70bfe710aeb..f5a15de2cba 100644 --- a/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/diag_e_point.hpp @@ -50,9 +50,7 @@ class diag_e_point : public ps_point { writer(inv_e_metric_ss.str()); } - inline std::string metric_type() { - return "diag_e"; - } + inline std::string metric_type() { return "diag_e"; } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp b/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp index d5ba1e01aee..b782e2a09cf 100644 --- a/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/softabs_point.hpp @@ -44,9 +44,7 @@ class softabs_point : public ps_point { writer("No free parameters for SoftAbs metric"); } - inline std::string metric_type() { - return "softabs"; - } + inline std::string metric_type() { return "softabs"; } }; } // namespace mcmc diff --git a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp index 5282b3b3ba2..c4e99334648 100644 --- a/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp +++ b/src/stan/mcmc/hmc/hamiltonians/unit_e_point.hpp @@ -31,9 +31,7 @@ class unit_e_point : public ps_point { writer("No free parameters for unit metric"); } - inline std::string metric_type() { - return "unit_e"; - } + inline std::string metric_type() { return "unit_e"; } }; } // namespace mcmc From 1af6cb25341878977e539047b912146cc7542e2d Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Sat, 9 Dec 2023 12:57:51 -0500 Subject: [PATCH 4/4] changes per code review --- src/stan/mcmc/hmc/base_hmc.hpp | 2 +- .../sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp | 1 + .../sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp | 1 + .../services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp | 1 + 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/stan/mcmc/hmc/base_hmc.hpp b/src/stan/mcmc/hmc/base_hmc.hpp index 9895b10f6ca..80c26f75957 100644 --- a/src/stan/mcmc/hmc/base_hmc.hpp +++ b/src/stan/mcmc/hmc/base_hmc.hpp @@ -70,8 +70,8 @@ class base_hmc : public base_mcmc { void write_sampler_state_struct(callbacks::structured_writer& struct_writer) { struct_writer.begin_record(); struct_writer.write("stepsize", get_nominal_stepsize()); - struct_writer.write("inv_metric", z_.inv_e_metric_); struct_writer.write("metric_type", z_.metric_type()); + struct_writer.write("inv_metric", z_.inv_e_metric_); struct_writer.end_record(); } diff --git a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp index 1250f05dc91..f98c59a820a 100644 --- a/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_dense_e_adapt_parallel_match_test.cpp @@ -125,6 +125,7 @@ TEST_F(ServicesSampleHmcNutsDenseEAdaptParMatch, single_multi_match) { par_metrics.push_back(ss_metric[i].str()); ASSERT_TRUE(stan::test::is_valid_JSON(par_metrics[i])); EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1); + EXPECT_EQ(count_matches("metric_type", par_metrics[i]), 1); EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1); EXPECT_EQ(count_matches("[", par_metrics[i]), 3); // list has 2 rows } diff --git a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp index 50da61b09fb..ca3b326c54f 100644 --- a/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_diag_e_adapt_parallel_match_test.cpp @@ -124,6 +124,7 @@ TEST_F(ServicesSampleHmcNutsDiagEAdaptParMatch, single_multi_match) { par_metrics.push_back(ss_metric[i].str()); ASSERT_TRUE(stan::test::is_valid_JSON(par_metrics[i])); EXPECT_EQ(count_matches("stepsize", par_metrics[i]), 1); + EXPECT_EQ(count_matches("metric_type", par_metrics[i]), 1); EXPECT_EQ(count_matches("inv_metric", par_metrics[i]), 1); EXPECT_EQ(count_matches("[", par_metrics[i]), 1); // single list } diff --git a/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp b/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp index 5574fa685a6..2889c6b2295 100644 --- a/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp +++ b/src/test/unit/services/sample/hmc_nuts_unit_e_adapt_parallel_test.cpp @@ -128,6 +128,7 @@ TEST_F(ServicesSampleHmcNutsUnitEAdaptPar, parameter_checks) { // Adapted metric ASSERT_TRUE(stan::test::is_valid_JSON(metric)); EXPECT_EQ(count_matches("stepsize", metric), 1); + EXPECT_EQ(count_matches("metric_type", metric), 1); EXPECT_EQ(count_matches("inv_metric", metric), 1); EXPECT_EQ(count_matches("[", metric), 1); // single list EXPECT_EQ(count_matches("[ 1, 1 ]", metric), 1); // unit diagonal