Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow NUTS to do eager evaluation on forward and backward trajectory in parallel #3103

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
19 changes: 13 additions & 6 deletions src/stan/mcmc/hmc/nuts/adapt_diag_e_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,25 @@ namespace mcmc {
* with a Gaussian-Euclidean disintegration and adaptive
* diagonal metric and adaptive step size
*/
template <class Model, class BaseRNG>
class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG>,
template <class Model, class BaseRNG, bool ParallelBase = false>
class adapt_diag_e_nuts : public diag_e_nuts<Model, BaseRNG, ParallelBase>,
public stepsize_var_adapter {
public:
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<!ParallelBase_>* = nullptr>
adapt_diag_e_nuts(const Model& model, BaseRNG& rng)
: diag_e_nuts<Model, BaseRNG>(model, rng),
: diag_e_nuts<Model, BaseRNG, ParallelBase>(model, rng),
stepsize_var_adapter(model.num_params_r()) {}

~adapt_diag_e_nuts() {}
template <bool ParallelBase_ = ParallelBase,
std::enable_if_t<ParallelBase_>* = nullptr>
adapt_diag_e_nuts(const Model& model, std::vector<BaseRNG>& thread_rngs)
: diag_e_nuts<Model, BaseRNG, ParallelBase>(model, thread_rngs),
stepsize_var_adapter(model.num_params_r()) {}

sample transition(sample& init_sample, callbacks::logger& logger) {
sample s = diag_e_nuts<Model, BaseRNG>::transition(init_sample, logger);
inline sample transition(sample& init_sample, callbacks::logger& logger) {
sample s = diag_e_nuts<Model, BaseRNG, ParallelBase>::transition(
init_sample, logger);

if (this->adapt_flag_) {
this->stepsize_adaptation_.learn_stepsize(this->nom_epsilon_,
Expand Down
Loading