Skip to content

Commit

Permalink
Initial CmdStan interface for low-rank ADVI
Browse files Browse the repository at this point in the history
  • Loading branch information
wjn0 committed Mar 17, 2021
1 parent a67347b commit 277b370
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/cmdstan/arguments/arg_variational.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cmdstan/arguments/arg_variational_iter.hpp>
#include <cmdstan/arguments/arg_variational_num_samples.hpp>
#include <cmdstan/arguments/arg_variational_output_samples.hpp>
#include <cmdstan/arguments/arg_variational_rank.hpp>
#include <cmdstan/arguments/categorical_argument.hpp>
#include <stan/services/experimental/advi/defaults.hpp>

Expand All @@ -27,6 +28,7 @@ class arg_variational : public categorical_argument {

_subarguments.push_back(new arg_variational_algo());
_subarguments.push_back(new arg_variational_iter());
_subarguments.push_back(new arg_variational_rank());
_subarguments.push_back(new arg_variational_num_samples(
"grad_samples", gradient_samples::description().c_str(),
gradient_samples::default_value()));
Expand Down
2 changes: 2 additions & 0 deletions src/cmdstan/arguments/arg_variational_algo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cmdstan/arguments/list_argument.hpp>
#include <cmdstan/arguments/arg_variational_fullrank.hpp>
#include <cmdstan/arguments/arg_variational_meanfield.hpp>
#include <cmdstan/arguments/arg_variational_lowrank.hpp>

namespace cmdstan {

Expand All @@ -15,6 +16,7 @@ class arg_variational_algo : public list_argument {

_values.push_back(new arg_variational_meanfield());
_values.push_back(new arg_variational_fullrank());
_values.push_back(new arg_variational_lowrank());

_default_cursor = 0;
_cursor = _default_cursor;
Expand Down
16 changes: 16 additions & 0 deletions src/cmdstan/arguments/arg_variational_lowrank.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#ifndef CMDSTAN_ARGUMENTS_VARIATIONAL_LOWRANK_HPP
#define CMDSTAN_ARGUMENTS_VARIATIONAL_LOWRANK_HPP

#include <cmdstan/arguments/categorical_argument.hpp>

namespace cmdstan {

class arg_variational_lowrank : public categorical_argument {
public:
arg_variational_lowrank() {
_name = "lowrank";
_description = "low-rank covariance";
}
};
} // namespace cmdstan
#endif
28 changes: 28 additions & 0 deletions src/cmdstan/arguments/arg_variational_rank.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef CMDSTAN_ARGUMENTS_VARIATIONAL_RANK_HPP
#define CMDSTAN_ARGUMENTS_VARIATIONAL_RANK_HPP

#include <cmdstan/arguments/singleton_argument.hpp>
#include <boost/lexical_cast.hpp>
#include <string>

namespace cmdstan {

using stan::services::experimental::advi::rank;

class arg_variational_rank : public int_argument {
public:
arg_variational_rank() : int_argument() {
_name = "rank";
_description = rank::description();
_validity = "0 <= rank";
_default = boost::lexical_cast<std::string>(rank::default_value());
_default_value = rank::default_value();
_constrained = true;
_good_value = rank::default_value();
_bad_value = -1.0;
_value = rank::default_value();
}
bool is_valid(int value) { return value >= 0; }
};
} // namespace cmdstan
#endif
10 changes: 10 additions & 0 deletions src/cmdstan/command.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <stan/services/diagnose/diagnose.hpp>
#include <stan/services/experimental/advi/fullrank.hpp>
#include <stan/services/experimental/advi/meanfield.hpp>
#include <stan/services/experimental/advi/lowrank.hpp>
#include <stan/services/optimize/bfgs.hpp>
#include <stan/services/optimize/lbfgs.hpp>
#include <stan/services/optimize/newton.hpp>
Expand Down Expand Up @@ -877,6 +878,15 @@ int command(int argc, const char *argv[]) {
elbo_samples, max_iterations, tol_rel_obj, eta, adapt_engaged,
adapt_iterations, eval_elbo, output_samples, interrupt, logger,
init_writer, sample_writer, diagnostic_writer);
} else if (algo->value() == "lowrank") {
int rank
= dynamic_cast<int_argument *>(
parser.arg("method")->arg("variational")->arg("rank"))->value();
return_code = stan::services::experimental::advi::lowrank(
model, *init_context, random_seed, id, init_radius, grad_samples,
elbo_samples, max_iterations, tol_rel_obj, rank, eta, adapt_engaged,
adapt_iterations, eval_elbo, output_samples, interrupt, logger,
init_writer, sample_writer, diagnostic_writer);
}
}
stan::math::profile_map &profile_data = get_stan_profile_data();
Expand Down
11 changes: 11 additions & 0 deletions src/test/interface/variational_output_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,14 @@ TEST_F(CmdStan, variational_fullrank) {
ASSERT_EQ(1, chains.num_chains());
ASSERT_EQ(1001, chains.num_samples());
}

TEST_F(CmdStan, variational_lowrank) {
run_command_output out
= run_command(base_command + " variational algorithm=lowrank rank=1");

ASSERT_EQ(0, out.err_code);

stan::mcmc::chains<> chains = parse_output_file();
ASSERT_EQ(1, chains.num_chains());
ASSERT_EQ(1001, chains.num_samples());
}

0 comments on commit 277b370

Please sign in to comment.