Skip to content

Commit

Permalink
Merge pull request #3139 from stan-dev/feature/param-names-dims-flags
Browse files Browse the repository at this point in the history
Extend `get_dims` and `get_param_names` in model_base
  • Loading branch information
WardBrian committed Mar 9, 2023
2 parents f9fa778 + f3cddca commit 5118264
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 45 deletions.
18 changes: 14 additions & 4 deletions src/stan/model/model_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ class model_base : public prob_grad {
*
* @param[in,out] names sequence of names parameters, transformed
* parameters, and generated quantities
* @param[in] include_tparams true if transformed parameters should
* be included
* @param[in] include_gqs true if generated quantities should be
* included
*/
virtual void get_param_names(std::vector<std::string>& names) const = 0;

virtual void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const = 0;
/**
* Set the dimensionalities of constrained parameters, transformed
* parameters, and generated quantities. The input sequence is
Expand All @@ -90,9 +95,14 @@ class model_base : public prob_grad {
* dimensionality `std::vector<size_t>{2, 3, 4}`.
*
* @param[in,out] dimss sequence of dimensions specifications to set
* @param[in] include_tparams true if transformed parameters should
* be included
* @param[in] include_gqs true if generated quantities should be
* included
*/
virtual void get_dims(std::vector<std::vector<size_t> >& dimss) const = 0;

virtual void get_dims(std::vector<std::vector<size_t> >& dimss,
bool include_tparams = true,
bool include_gqs = true) const = 0;
/**
* Set the specified sequence to the indexed, scalar, constrained
* parameter names. Each variable is output with a
Expand Down
28 changes: 8 additions & 20 deletions src/stan/services/sample/standalone_gqs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <stan/services/error_codes.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/services/util/gq_writer.hpp>
#include <boost/algorithm/string.hpp>
#include <iostream>
#include <string>
#include <vector>
Expand All @@ -31,27 +30,16 @@ template <class Model>
void get_model_parameters(const Model &model,
std::vector<std::string> &param_names,
std::vector<std::vector<size_t>> &param_dimss) {
std::vector<std::string> param_cols;
model.constrained_param_names(param_cols, false, false);
std::string cur_name("");
std::vector<std::string> splits;
for (size_t i = 0; i < param_cols.size(); ++i) {
boost::algorithm::split(splits, param_cols[i], boost::is_any_of("."));
if (splits.size() == 1 || splits[0] != cur_name) {
cur_name = splits[0];
param_names.emplace_back(cur_name);
}
}
std::vector<std::string> all_param_names;
model.get_param_names(all_param_names);
model.get_param_names(all_param_names, false, false);
std::vector<std::vector<size_t>> dimss;
model.get_dims(dimss);
for (size_t i = 0; i < param_names.size(); i++) {
for (size_t j = i; j < all_param_names.size(); ++j) {
if (param_names[i].compare(all_param_names[j]) == 0) {
param_dimss.emplace_back(dimss[j]);
break;
}
model.get_dims(dimss, false, false);
// remove zero-size
for (size_t i = 0; i < all_param_names.size(); i++) {
auto &v = dimss[i];
if (std::find(v.begin(), v.end(), 0) == v.end()) {
param_names.emplace_back(all_param_names[i]);
param_dimss.emplace_back(dimss[i]);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/stan/services/util/initialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
bool is_fully_initialized = true;
bool any_initialized = false;
std::vector<std::string> param_names;
model.get_param_names(param_names);
model.get_param_names(param_names, false, false);
for (size_t n = 0; n < param_names.size(); n++) {
is_fully_initialized &= init.contains_r(param_names[n]);
any_initialized |= init.contains_r(param_names[n]);
Expand Down
6 changes: 4 additions & 2 deletions src/test/unit/model/model_base_crtp_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ struct mock_model : public stan::model::model_base_crtp<mock_model> {
return stanc_info;
}

void get_param_names(std::vector<std::string>& names) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss) const override {}
void get_param_names(std::vector<std::string>& names, bool include_tparams,
bool include_gqs) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss, bool include_tparams,
bool include_gqs) const override {}

void constrained_param_names(std::vector<std::string>& param_names,
bool include_tparams,
Expand Down
6 changes: 4 additions & 2 deletions src/test/unit/model/model_base_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ struct mock_model : public stan::model::model_base {
return stanc_info;
}

void get_param_names(std::vector<std::string>& names) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss) const override {}
void get_param_names(std::vector<std::string>& names, bool include_tparams,
bool include_gqs) const override {}
void get_dims(std::vector<std::vector<size_t> >& dimss, bool include_tparams,
bool include_gqs) const override {}

void constrained_param_names(std::vector<std::string>& param_names,
bool include_tparams,
Expand Down
21 changes: 15 additions & 6 deletions src/test/unit/services/util/initialize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ class mock_throwing_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -213,7 +214,9 @@ class mock_throwing_model : public stan::model::prob_grad {
param_names__.push_back("theta");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -336,7 +339,8 @@ class mock_error_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -348,7 +352,9 @@ class mock_error_model : public stan::model::prob_grad {
param_names__.push_back("theta");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -471,7 +477,8 @@ class mock_throwing_model_in_write_array : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -483,7 +490,9 @@ class mock_throwing_model_in_write_array : public stan::model::prob_grad {
param_names__.push_back("theta");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down
8 changes: 6 additions & 2 deletions src/test/unit/services/util/mcmc_writer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,18 @@ class throwing_model : public stan::model::model_base_crtp<throwing_model> {

} // transform_inits()

inline void get_param_names(std::vector<std::string>& names__) const {
inline void get_param_names(std::vector<std::string>& names__,
bool include_tparams = true,
bool include_gqs = true) const {
names__.clear();
names__.emplace_back("y");
names__.emplace_back("z");
names__.emplace_back("xgq");
} // get_param_names()

inline void get_dims(std::vector<std::vector<size_t>>& dimss__) const final {
inline void get_dims(std::vector<std::vector<size_t>>& dimss__,
bool include_tparams = true,
bool include_gqs = true) const final {
dimss__.clear();
dimss__.emplace_back(std::vector<size_t>{static_cast<size_t>(2)});

Expand Down
14 changes: 10 additions & 4 deletions src/test/unit/variational/eta_adapt_mock_models_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ class mock_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -59,7 +60,9 @@ class mock_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -124,7 +127,8 @@ class mock_throwing_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -140,7 +144,9 @@ class mock_throwing_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down
14 changes: 10 additions & 4 deletions src/test/unit/variational/stochastic_gradient_ascent_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class mock_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -60,7 +61,9 @@ class mock_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down Expand Up @@ -125,7 +128,8 @@ class mock_throwing_model : public stan::model::prob_grad {
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__) const {
void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
Expand All @@ -141,7 +145,9 @@ class mock_throwing_model : public stan::model::prob_grad {
param_names__.push_back("c");
}

void get_param_names(std::vector<std::string>& names) const {
void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

Expand Down

0 comments on commit 5118264

Please sign in to comment.