
Parallel Adaptive Nuts #3033

Merged · 67 commits · Jul 26, 2021

Commits (67)
5edf4e2
adds file_stream_writer and run_adaptive_sampler method for parallel
SteveBronder Mar 20, 2021
faa5eb0
fix grammar error for parallel run_adaptive_sampler
SteveBronder Mar 21, 2021
751403f
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 21, 2021
a08ba35
adds test for parallel adaptive
SteveBronder Mar 21, 2021
a4c9679
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 21, 2021
04836ef
include mutex
SteveBronder Mar 21, 2021
0608601
make stream_writer not final
SteveBronder Mar 21, 2021
3ebbc5c
init threadpool
SteveBronder Mar 21, 2021
a3bf12b
update generate_transitions and cleanup run_adaptive_sampler
SteveBronder Mar 24, 2021
9056101
Merge commit '7eeaf3c58fdd1c40aa62ba7158106529f2fd9563' into HEAD
yashikno Mar 24, 2021
5c7b0bd
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 24, 2021
117274f
start diag_e_adapt parallel
SteveBronder Mar 24, 2021
99d64c1
adds tests for parallel adapter
SteveBronder Mar 24, 2021
35f88a7
Merge remote-tracking branch 'origin/develop' into feature/parallel-a…
SteveBronder Mar 24, 2021
ebeb1f9
remove statics from softmax metric and cleanup tests
SteveBronder Mar 24, 2021
142a94a
update to feature/parallel-adapt
SteveBronder Mar 24, 2021
558eeb0
use normal diag_e_adapt if n_chain == 0
SteveBronder Mar 24, 2021
5055e21
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 25, 2021
34114be
use parallel loop around run_adaptive_sampler
SteveBronder Mar 25, 2021
8458782
remove run_adaptive_sampler parallel version
SteveBronder Mar 25, 2021
38e168d
update to remove parallel run_adaptive_sampler
SteveBronder Mar 25, 2021
e72409b
Merge commit 'f3bf21bc20271ebb9f7c9613bdb17c16d5cc0c1b' into HEAD
yashikno Mar 25, 2021
a2daea0
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 25, 2021
427769c
update get_context to return back the input
SteveBronder Mar 25, 2021
f654f37
update math
SteveBronder Mar 29, 2021
df1542e
update with docs
SteveBronder Mar 30, 2021
0a5075b
update stan math
SteveBronder Mar 30, 2021
ce2b7a1
update file streamer and adapt
SteveBronder Mar 30, 2021
e2dfbaf
update stan math
SteveBronder Mar 30, 2021
fce9734
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 30, 2021
059a176
adds dense_e_adapt for parallel chains
SteveBronder Mar 31, 2021
d99daa0
Merge branch 'feature/parallel-nuts' of github.com:stan-dev/stan into…
SteveBronder Mar 31, 2021
7dd37b6
Merge commit '55f6e5794c1a5a76f8cf23bc154dd254ede3258e' into HEAD
yashikno Mar 31, 2021
147fba5
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 31, 2021
ddfcf86
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Apr 13, 2021
a2666e4
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Apr 22, 2021
323adb9
update to reflect parallel chain design doc
SteveBronder Apr 22, 2021
7621468
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 22, 2021
f69b139
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder May 12, 2021
7146784
make changes so API matches up with design doc
SteveBronder May 12, 2021
e1e37a4
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 12, 2021
ea9a556
update docs
SteveBronder May 12, 2021
f3adfb6
Merge branch 'feature/parallel-nuts' of github.com:stan-dev/stan into…
SteveBronder May 12, 2021
57031ac
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 12, 2021
ada9dcb
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Jun 15, 2021
bd76742
update so tbb is only used when STAN_THREADS is defined
SteveBronder Jun 15, 2021
851ba01
update to use simple partitioner
SteveBronder Jun 15, 2021
429e7de
fix chain_id default so all is 1
SteveBronder Jun 15, 2021
c66a2fd
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jun 15, 2021
45d64b3
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Jun 28, 2021
724f887
update create_rng() to follow the design doc
SteveBronder Jun 28, 2021
203d6f3
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jun 28, 2021
fa22037
adds template to unique_stream_writer to set the stream type, adds te…
SteveBronder Jun 29, 2021
34baa29
Merge branch 'feature/parallel-nuts' of github.com:stan-dev/stan into…
SteveBronder Jun 29, 2021
d0e501b
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Jun 29, 2021
a115c82
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jun 29, 2021
07b2230
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Jun 30, 2021
51c4472
update dense and diag adapt for math changes
SteveBronder Jun 30, 2021
9866ef0
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Jul 1, 2021
1ad42b4
cleanup docs and change printout of chain number to take the initial …
SteveBronder Jul 1, 2021
c50cf1a
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 1, 2021
c09ca13
update printing logic
SteveBronder Jul 2, 2021
91d51e1
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 2, 2021
569df85
update to remove +1 from chain number
SteveBronder Jul 7, 2021
bbf4d37
Merge remote-tracking branch 'origin/develop' into feature/parallel-nuts
SteveBronder Jul 19, 2021
4f82e24
Make sure to use num_chains instead of n_chain everywhere, fix docs, …
SteveBronder Jul 19, 2021
a11a48a
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jul 19, 2021
2 changes: 1 addition & 1 deletion src/stan/callbacks/stream_logger.hpp
@@ -14,7 +14,7 @@ namespace callbacks {
* <code>logger</code> that writes messages to separate
* std::stringstream outputs.
*/
-class stream_logger : public logger {
+class stream_logger final : public logger {
private:
std::ostream& debug_;
std::ostream& info_;
2 changes: 1 addition & 1 deletion src/stan/callbacks/tee_writer.hpp
@@ -16,7 +16,7 @@ namespace callbacks {
* For any call to this writer, it will tee the call to both writers
* provided in the constructor.
*/
-class tee_writer : public writer {
+class tee_writer final : public writer {
public:
/**
* Constructor accepting two writers.
127 changes: 127 additions & 0 deletions src/stan/callbacks/unique_stream_writer.hpp
@@ -0,0 +1,127 @@
#ifndef STAN_CALLBACKS_UNIQUE_STREAM_WRITER_HPP
#define STAN_CALLBACKS_UNIQUE_STREAM_WRITER_HPP

#include <stan/callbacks/writer.hpp>
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

namespace stan {
namespace callbacks {

/**
* <code>unique_stream_writer</code> is an implementation
* of <code>writer</code> that holds a unique pointer to the stream it is
* writing to.
* @tparam Stream A type with a valid `operator<<(std::string)`
*/
template <typename Stream>
class unique_stream_writer final : public writer {
public:
/**
* Constructs a unique stream writer with an output stream
* and an optional prefix for comments.
*
* @param[in, out] output A unique pointer to a type inheriting from `std::ostream`
* @param[in] comment_prefix string to stream before each comment line.
* Default is "".
*/
explicit unique_stream_writer(std::unique_ptr<Stream>&& output,
const std::string& comment_prefix = "")
: output_(std::move(output)), comment_prefix_(comment_prefix) {}

unique_stream_writer();
unique_stream_writer(unique_stream_writer& other) = delete;
unique_stream_writer(unique_stream_writer&& other)
: output_(std::move(other.output_)),
comment_prefix_(std::move(other.comment_prefix_)) {}
/**
* Virtual destructor
*/
virtual ~unique_stream_writer() {}

/**
* Writes a set of names on a single line in csv format followed
* by a newline.
*
* Note: the names are not escaped.
*
* @param[in] names Names in a std::vector
*/
void operator()(const std::vector<std::string>& names) {
write_vector(names);
}
/**
* Get the underlying stream
*/
auto& get_stream() { return *output_; }

/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] state Values in a std::vector
*/
void operator()(const std::vector<double>& state) { write_vector(state); }

/**
* Writes the comment_prefix to the stream followed by a newline.
*/
void operator()() {
std::stringstream streamer;
streamer << comment_prefix_ << std::endl;
*output_ << streamer.str();
}

/**
* Writes the comment_prefix then the message followed by a newline.
*
* @param[in] message A string
*/
void operator()(const std::string& message) {
std::stringstream streamer;
streamer << comment_prefix_ << message << std::endl;
*output_ << streamer.str();
}

private:
/**
* Output stream
*/
std::unique_ptr<Stream> output_;

/**
* Comment prefix to use when printing comments: strings and blank lines
*/
std::string comment_prefix_;

/**
* Writes a set of values in csv format followed by a newline.
*
* Note: the precision of the output is determined by the settings
* of the stream on construction.
*
* @param[in] v Values in a std::vector
*/
template <class T>
void write_vector(const std::vector<T>& v) {
if (v.empty())
return;
using const_iter = typename std::vector<T>::const_iterator;
const_iter last = v.end();
--last;
std::stringstream streamer;
for (const_iter it = v.begin(); it != last; ++it) {
streamer << *it << ",";
}
streamer << v.back() << std::endl;
*output_ << streamer.str();
}
};

} // namespace callbacks
} // namespace stan

#endif
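
A minimal usage sketch for the new writer (not part of the diff above; the file names and the "# " comment prefix are illustrative, and it assumes the Stan headers are on the include path). Each chain owns its own `unique_stream_writer`, so no two chains ever share an output stream:

```cpp
#include <stan/callbacks/unique_stream_writer.hpp>
#include <cstddef>
#include <fstream>
#include <memory>
#include <string>
#include <vector>

int main() {
  const std::size_t num_chains = 4;
  std::vector<stan::callbacks::unique_stream_writer<std::ofstream>> writers;
  writers.reserve(num_chains);
  for (std::size_t i = 0; i < num_chains; ++i) {
    // Each writer takes sole ownership of its output stream via unique_ptr.
    auto out = std::make_unique<std::ofstream>(
        "samples_" + std::to_string(i + 1) + ".csv");
    writers.emplace_back(std::move(out), "# ");
  }
  // Header row for chain 1; string messages are written as "# ..." comments.
  std::vector<std::string> names{"lp__", "theta"};
  writers[0](names);
  writers[0]("Adaptation terminated");
}
```

A per-chain vector of writers like this is the kind of argument the multi-chain service overloads further down expect for `sample_writer` and `diagnostic_writer`.
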
12 changes: 6 additions & 6 deletions src/stan/mcmc/hmc/hamiltonians/softabs_metric.hpp
@@ -174,26 +174,26 @@ class softabs_metric : public base_hamiltonian<Model, softabs_point, BaseRNG> {

// Threshold below which a power series
// approximation of the softabs function is used
-static double lower_softabs_thresh;
+static constexpr double lower_softabs_thresh = 1e-4;

// Threshold above which an asymptotic
// approximation of the softabs function is used
-static double upper_softabs_thresh;
+static constexpr double upper_softabs_thresh = 18;

// Threshold below which an exact derivative is
// used in the Jacobian calculation instead of
// finite differencing
-static double jacobian_thresh;
+static constexpr double jacobian_thresh = 1e-10;
};

template <class Model, class BaseRNG>
-double softabs_metric<Model, BaseRNG>::lower_softabs_thresh = 1e-4;
+constexpr double softabs_metric<Model, BaseRNG>::lower_softabs_thresh;

template <class Model, class BaseRNG>
-double softabs_metric<Model, BaseRNG>::upper_softabs_thresh = 18;
+constexpr double softabs_metric<Model, BaseRNG>::upper_softabs_thresh;

template <class Model, class BaseRNG>
-double softabs_metric<Model, BaseRNG>::jacobian_thresh = 1e-10;
+constexpr double softabs_metric<Model, BaseRNG>::jacobian_thresh;
} // namespace mcmc
} // namespace stan
#endif
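
The change above turns the softabs thresholds into compile-time constants, removing the non-const static data members (per the commit "remove statics from softmax metric"), presumably so the metric carries no shared mutable state once several chains run inside one process. A standalone sketch (not the PR's code) of the language pattern involved: a `static constexpr` data member of a class template still needs an out-of-class definition before C++17 whenever it is odr-used.

```cpp
#include <iostream>

template <class T>
struct thresholds {
  // In-class initializer: the value is a compile-time constant.
  static constexpr double lower_thresh = 1e-4;
};

// Out-of-class definition, required before C++17 when the member is odr-used
// (e.g. bound to a reference or its address taken). In C++17 and later,
// static constexpr members are implicitly inline and this line is redundant
// but still accepted.
template <class T>
constexpr double thresholds<T>::lower_thresh;

int main() {
  const double& r = thresholds<int>::lower_thresh;  // odr-use of the member
  std::cout << r << "\n";
}
```
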
204 changes: 203 additions & 1 deletion src/stan/services/sample/hmc_nuts_dense_e_adapt.hpp
@@ -66,7 +66,6 @@ int hmc_nuts_dense_e_adapt(
callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer) {
boost::ecuyer1988 rng = util::create_rng(random_seed, chain);

-std::vector<int> disc_vector;
std::vector<double> cont_vector = util::initialize(
model, init, rng, init_radius, true, logger, init_writer);

@@ -156,6 +155,209 @@
interrupt, logger, init_writer, sample_writer, diagnostic_writer);
}

/**
 * Runs multiple chains of NUTS with adaptation using a dense Euclidean metric,
 * with a pre-specified inverse Euclidean metric for each chain.
 *
 * @tparam Model Model class
 * @tparam InitContextPtr A pointer with underlying type derived from
 *   `stan::io::var_context`
 * @tparam InitInvContextPtr A pointer with underlying type derived from
 *   `stan::io::var_context`
 * @tparam SampleWriter A type derived from `stan::callbacks::writer`
 * @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
 * @tparam InitWriter A type derived from `stan::callbacks::writer`
 * @param[in] model Input model to test (with data already instantiated)
 * @param[in] num_chains The number of chains to run in parallel. `init`,
 *   `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer`
 *   must be the same length as this value.
 * @param[in] init An std vector of init var contexts for initialization of
 *   each chain.
 * @param[in] init_inv_metric An std vector of var contexts exposing an initial
 *   dense inverse Euclidean metric for each chain (must be positive definite)
 * @param[in] random_seed random seed for the random number generator
 * @param[in] init_chain_id first chain id. The chain ids form the integer
 *   sequence from `init_chain_id` to `init_chain_id + num_chains - 1`, and the
 *   pseudo random number generator of each chain is advanced according to its
 *   chain id.
 * @param[in] init_radius radius to initialize
 * @param[in] num_warmup Number of warmup samples
 * @param[in] num_samples Number of samples
 * @param[in] num_thin Number to thin the samples
 * @param[in] save_warmup Indicates whether to save the warmup iterations
 * @param[in] refresh Controls the output
 * @param[in] stepsize initial stepsize for discrete evolution
 * @param[in] stepsize_jitter uniform random jitter of stepsize
 * @param[in] max_depth Maximum tree depth
 * @param[in] delta adaptation target acceptance statistic
 * @param[in] gamma adaptation regularization scale
 * @param[in] kappa adaptation relaxation exponent
 * @param[in] t0 adaptation iteration offset
 * @param[in] init_buffer width of initial fast adaptation interval
 * @param[in] term_buffer width of final fast adaptation interval
 * @param[in] window initial width of slow adaptation interval
 * @param[in,out] interrupt Callback for interrupts
 * @param[in,out] logger Logger for messages
 * @param[in,out] init_writer std vector of Writer callbacks for unconstrained
 *   inits of each chain.
 * @param[in,out] sample_writer std vector of Writers for draws of each chain.
 * @param[in,out] diagnostic_writer std vector of Writers for diagnostic
 *   information of each chain.
 * @return error_codes::OK if successful
 */
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_dense_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
const std::vector<InitInvContextPtr>& init_inv_metric,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
int num_warmup, int num_samples, int num_thin, bool save_warmup,
int refresh, double stepsize, double stepsize_jitter, int max_depth,
double delta, double gamma, double kappa, double t0,
unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
if (num_chains == 1) {

Review comment (Contributor): Why the code duplication here? Why not make the 1-core function defer to the multi-chain function by wrapping things into 1-sized vectors?

Reply (Collaborator, author): I don't have a strong opinion about this, but tbh I'd like to just make sure we are 100% cool with backwards compatibility by always deferring to the original API version if chains=1.

Reply (Contributor): Na... let's just ensure we have good tests for that and instead make the old code call the new one with chains=1. One issue you have when doing that is that you get references to inits and other things passed into the function. These then need to be wrapped into a size-1 vector of refs. Since you went with templates, one can then use std::ref to make this work if I am not mistaken. I'd really like to avoid code duplication, which is a source of errors in the future.

(A sketch of this size-1 wrapping approach appears after this function below.)

return hmc_nuts_dense_e_adapt(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
}
using sample_t = stan::mcmc::adapt_dense_e_nuts<Model, boost::ecuyer1988>;
std::vector<boost::ecuyer1988> rngs;
rngs.reserve(num_chains);
std::vector<std::vector<double>> cont_vectors;
cont_vectors.reserve(num_chains);
std::vector<sample_t> samplers;
samplers.reserve(num_chains);
try {
for (int i = 0; i < num_chains; ++i) {
rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i));
cont_vectors.emplace_back(util::initialize(
model, *init[i], rngs[i], init_radius, true, logger, init_writer[i]));
Eigen::MatrixXd inv_metric = util::read_dense_inv_metric(
*init_inv_metric[i], model.num_params_r(), logger);
util::validate_dense_inv_metric(inv_metric, logger);

samplers.emplace_back(model, rngs[i]);
samplers[i].set_metric(inv_metric);
samplers[i].set_nominal_stepsize(stepsize);
samplers[i].set_stepsize_jitter(stepsize_jitter);
samplers[i].set_max_depth(max_depth);

samplers[i].get_stepsize_adaptation().set_mu(log(10 * stepsize));
samplers[i].get_stepsize_adaptation().set_delta(delta);
samplers[i].get_stepsize_adaptation().set_gamma(gamma);
samplers[i].get_stepsize_adaptation().set_kappa(kappa);
samplers[i].get_stepsize_adaptation().set_t0(t0);
samplers[i].set_window_params(num_warmup, init_buffer, term_buffer,
window, logger);
}
} catch (const std::domain_error& e) {
return error_codes::CONFIG;
}
tbb::parallel_for(tbb::blocked_range<size_t>(0, num_chains, 1),
[num_warmup, num_samples, num_thin, refresh, save_warmup,
num_chains, init_chain_id, &samplers, &model, &rngs,
&interrupt, &logger, &sample_writer, &cont_vectors,
&diagnostic_writer](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i) {
util::run_adaptive_sampler(
samplers[i], model, cont_vectors[i], num_warmup,
num_samples, num_thin, refresh, save_warmup,
rngs[i], interrupt, logger, sample_writer[i],
diagnostic_writer[i], init_chain_id + i,
num_chains);
}
},
tbb::simple_partitioner());
return error_codes::OK;
}
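
The sketch referenced in the review discussion above: one hypothetical way (not what this PR does) to let a single-chain entry point forward to a multi-chain one without duplicating sampler setup is to wrap its reference arguments into size-1 vectors of `std::reference_wrapper`. The names here (`demo_writer`, `run_chains`, `run_single_chain`) are illustrative stand-ins, not Stan API.

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for a callbacks::writer-like type (illustrative only).
struct demo_writer {
  void operator()(const std::string& msg) { std::cout << msg << "\n"; }
};

// Stand-in for the multi-chain entry point: each vector has one entry per chain.
template <typename Writer>
int run_chains(std::size_t num_chains, std::vector<Writer>& sample_writers) {
  for (std::size_t i = 0; i < num_chains; ++i) {
    // std::reference_wrapper forwards operator() to the wrapped writer.
    sample_writers[i]("chain " + std::to_string(i + 1) + " done");
  }
  return 0;
}

// Single-chain wrapper: no duplicated sampler setup, just size-1 wrapping.
template <typename Writer>
int run_single_chain(Writer& sample_writer) {
  std::vector<std::reference_wrapper<Writer>> sample_writers{
      std::ref(sample_writer)};
  return run_chains(1, sample_writers);
}

int main() {
  demo_writer w;
  return run_single_chain(w);  // prints "chain 1 done"
}
```

As merged, the PR went the other way around: the multi-chain overloads simply call the existing single-chain functions when `num_chains == 1`, as the `if (num_chains == 1)` branches above and below show.
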

/**
 * Runs multiple chains of NUTS with adaptation using a dense Euclidean metric,
 * with the identity matrix as the initial inverse metric.
*
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
 * @tparam SampleWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model to test (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @param[in] init An std vector of init var contexts for initialization of each
* chain.
* @param[in] random_seed random seed for the random number generator
 * @param[in] init_chain_id first chain id. The chain ids form the integer
 * sequence from `init_chain_id` to `init_chain_id + num_chains - 1`, and the
 * pseudo random number generator of each chain is advanced according to its
 * chain id.
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
* @param[in] num_thin Number to thin the samples
* @param[in] save_warmup Indicates whether to save the warmup iterations
* @param[in] refresh Controls the output
* @param[in] stepsize initial stepsize for discrete evolution
* @param[in] stepsize_jitter uniform random jitter of stepsize
* @param[in] max_depth Maximum tree depth
* @param[in] delta adaptation target acceptance statistic
* @param[in] gamma adaptation regularization scale
* @param[in] kappa adaptation relaxation exponent
* @param[in] t0 adaptation iteration offset
* @param[in] init_buffer width of initial fast adaptation interval
* @param[in] term_buffer width of final fast adaptation interval
* @param[in] window initial width of slow adaptation interval
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
* inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_dense_e_adapt(
Model& model, size_t num_chains, const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id, double init_radius,
int num_warmup, int num_samples, int num_thin, bool save_warmup,
int refresh, double stepsize, double stepsize_jitter, int max_depth,
double delta, double gamma, double kappa, double t0,
unsigned int init_buffer, unsigned int term_buffer, unsigned int window,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
if (num_chains == 1) {
return hmc_nuts_dense_e_adapt(
model, *init[0], random_seed, init_chain_id, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, delta, gamma, kappa, t0, init_buffer, term_buffer, window,
interrupt, logger, init_writer[0], sample_writer[0],
diagnostic_writer[0]);
}
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metrics;
unit_e_metrics.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
unit_e_metrics.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_dense_inv_metric(model.num_params_r())));
}
return hmc_nuts_dense_e_adapt(
model, num_chains, init, unit_e_metrics, random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, delta, gamma, kappa, t0,
init_buffer, term_buffer, window, interrupt, logger, init_writer,
sample_writer, diagnostic_writer);
}

} // namespace sample
} // namespace services
} // namespace stan
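
An illustrative sketch of the chain-id convention described in the documentation above: each chain gets id `init_chain_id + i`, and its random number generator is advanced by a chain-dependent amount so the chains draw from non-overlapping subsequences. This is not Stan's `util::create_rng` implementation; the generator and stride below are stand-ins (the real code uses `boost::ecuyer1988` advanced by a fixed large stride per chain id).

```cpp
#include <cstdint>
#include <iostream>
#include <random>

// Stand-in for per-chain RNG creation: same seed, chain-dependent skip ahead.
std::mt19937_64 make_chain_rng(std::uint64_t seed, std::uint64_t chain_id) {
  std::mt19937_64 rng(seed);
  // The stride here is arbitrary and purely for illustration.
  rng.discard(chain_id * (1ULL << 20));
  return rng;
}

int main() {
  const std::uint64_t seed = 1234;
  const std::uint64_t init_chain_id = 1;
  const std::uint64_t num_chains = 4;
  for (std::uint64_t i = 0; i < num_chains; ++i) {
    auto rng = make_chain_rng(seed, init_chain_id + i);
    std::cout << "chain " << (init_chain_id + i)
              << " first draw: " << rng() << "\n";
  }
}
```
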