Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend get_dims and get_param_names in model_base #3139

Merged
merged 10 commits into from
Mar 9, 2023
18 changes: 14 additions & 4 deletions src/stan/model/model_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ class model_base : public prob_grad {
*
* @param[in,out] names sequence of names parameters, transformed
* parameters, and generated quantities
* @param[in] include_tparams true if transformed parameters should
* be included
* @param[in] include_gqs true if generated quantities should be
* included
*/
virtual void get_param_names(std::vector<std::string>& names) const = 0;

virtual void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const = 0;
/**
* Set the dimensionalities of constrained parameters, transformed
* parameters, and generated quantities. The input sequence is
Expand All @@ -90,9 +95,14 @@ class model_base : public prob_grad {
* dimensionality `std::vector<size_t>{2, 3, 4}`.
*
* @param[in,out] dimss sequence of dimensions specifications to set
* @param[in] include_tparams true if transformed parameters should
* be included
* @param[in] include_gqs true if generated quantities should be
* included
*/
virtual void get_dims(std::vector<std::vector<size_t> >& dimss) const = 0;

virtual void get_dims(std::vector<std::vector<size_t> >& dimss,
bool include_tparams = true,
bool include_gqs = true) const = 0;
/**
* Set the specified sequence to the indexed, scalar, constrained
* parameter names. Each variable is output with a
Expand Down
28 changes: 8 additions & 20 deletions src/stan/services/sample/standalone_gqs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <stan/services/error_codes.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/gq_writer.hpp>
#include <boost/algorithm/string.hpp>
#include <iostream>
#include <string>
#include <vector>
Expand All @@ -31,27 +30,16 @@ template <class Model>
void get_model_parameters(const Model &model,
std::vector<std::string> &param_names,
std::vector<std::vector<size_t>> &param_dimss) {
std::vector<std::string> param_cols;
model.constrained_param_names(param_cols, false, false);
std::string cur_name("");
std::vector<std::string> splits;
for (size_t i = 0; i < param_cols.size(); ++i) {
boost::algorithm::split(splits, param_cols[i], boost::is_any_of("."));
if (splits.size() == 1 || splits[0] != cur_name) {
cur_name = splits[0];
param_names.emplace_back(cur_name);
}
}
std::vector<std::string> all_param_names;
model.get_param_names(all_param_names);
model.get_param_names(all_param_names, false, false);
std::vector<std::vector<size_t>> dimss;
model.get_dims(dimss);
for (size_t i = 0; i < param_names.size(); i++) {
for (size_t j = i; j < all_param_names.size(); ++j) {
if (param_names[i].compare(all_param_names[j]) == 0) {
param_dimss.emplace_back(dimss[j]);
break;
}
model.get_dims(dimss, false, false);
// remove zero-size
for (size_t i = 0; i < all_param_names.size(); i++) {
auto &v = dimss[i];
if (std::find(v.begin(), v.end(), 0) == v.end()) {
param_names.emplace_back(all_param_names[i]);
param_dimss.emplace_back(dimss[i]);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/util/initialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
bool is_fully_initialized = true;
bool any_initialized = false;
std::vector<std::string> param_names;
model.get_param_names(param_names);
model.get_param_names(param_names, false, false);
for (size_t n = 0; n < param_names.size(); n++) {
is_fully_initialized &= init.contains_r(param_names[n]);
any_initialized |= init.contains_r(param_names[n]);
Expand Down
6 changes: 4 additions & 2 deletions src/test/unit/model/model_base_crtp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ struct mock_model : public stan::model::model_base_crtp<mock_model> {
return stanc_info;
}

void get_param_names(std::vector<std::string>& names) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss) const override {}
void get_param_names(std::vector<std::string>& names, bool include_tparams,
bool include_gqs) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss, bool include_tparams,
bool include_gqs) const override {}

void constrained_param_names(std::vector<std::string>& param_names,
bool include_tparams,
Expand Down
6 changes: 4 additions & 2 deletions src/test/unit/model/model_base_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ struct mock_model : public stan::model::model_base {
return stanc_info;
}

void get_param_names(std::vector<std::string>& names) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss) const override {}
void get_param_names(std::vector<std::string>& names, bool include_tparams,
bool include_gqs) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss, bool include_tparams,
bool include_gqs) const override {}

void constrained_param_names(std::vector<std::string>& param_names,
bool include_tparams,
Expand Down
21 changes: 15 additions & 6 deletions src/test/unit/services/util/initialize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ class mock_throwing_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -213,7 +214,9 @@ class mock_throwing_model : public stan::model::prob_grad {
param_names__.push_back("theta");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -336,7 +339,8 @@ class mock_error_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -348,7 +352,9 @@ class mock_error_model : public stan::model::prob_grad {
param_names__.push_back("theta");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -471,7 +477,8 @@ class mock_throwing_model_in_write_array : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -483,7 +490,9 @@ class mock_throwing_model_in_write_array : public stan::model::prob_grad {
param_names__.push_back("theta");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down
8 changes: 6 additions & 2 deletions src/test/unit/services/util/mcmc_writer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ class throwing_model : public stan::model::model_base_crtp<throwing_model> {

} // transform_inits()

inline void get_param_names(std::vector<std::string>& names__) const {
inline void get_param_names(std::vector<std::string>& names__,
bool include_tparams = true,
bool include_gqs = true) const {
names__.clear();
names__.emplace_back("y");
names__.emplace_back("z");
names__.emplace_back("xgq");
} // get_param_names()

inline void get_dims(std::vector<std::vector<size_t>>& dimss__) const final {
inline void get_dims(std::vector<std::vector<size_t>>& dimss__,
bool include_tparams = true,
bool include_gqs = true) const final {
dimss__.clear();
dimss__.emplace_back(std::vector<size_t>{static_cast<size_t>(2)});

Expand Down
14 changes: 10 additions & 4 deletions src/test/unit/variational/eta_adapt_mock_models_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class mock_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -59,7 +60,9 @@ class mock_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -124,7 +127,8 @@ class mock_throwing_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -140,7 +144,9 @@ class mock_throwing_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down
14 changes: 10 additions & 4 deletions src/test/unit/variational/stochastic_gradient_ascent_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class mock_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -60,7 +61,9 @@ class mock_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -125,7 +128,8 @@ class mock_throwing_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -141,7 +145,9 @@ class mock_throwing_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down