Skip to content

Commit

Permalink
Don't require passing proposal to adapter initializers
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Oct 23, 2024
1 parent de17233 commit fa8bf5f
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 89 deletions.
52 changes: 27 additions & 25 deletions R/adaptation.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#' @param kappa Decay rate exponent in `[0.5, 1]` for adaptation learning rate.
#'
#' @return List of functions with entries
#' * `initialize`, a function for initializing adapter state at beginning of
#' chain,
#' * `initialize`, a function for initializing adapter state and proposal
#' parameters at beginning of chain,
#' * `update` a function for updating adapter state and proposal parameters on
#' each chain iteration,
#' * `finalize` a function for performing any final updates to adapter state and
Expand All @@ -27,24 +27,22 @@
#' grad_log_density = function(x) -x
#' )
#' proposal <- barker_proposal(target_distribution)
#' adapter <- scale_adapter(
#' proposal,
#' initial_scale = 1., target_accept_prob = 0.4
#' )
#' adapter <- scale_adapter(initial_scale = 1., target_accept_prob = 0.4)
#' adapter$initialize(proposal, chain_state(c(0, 0)))
scale_adapter <- function(
proposal, initial_scale = NULL, target_accept_prob = NULL, kappa = 0.6) {
initial_scale = NULL, target_accept_prob = NULL, kappa = 0.6) {
log_scale <- NULL
if (is.null(target_accept_prob)) {
target_accept_prob <- proposal$default_target_accept_prob()
}
initialize <- function(initial_state) {
initialize <- function(proposal, initial_state) {
if (is.null(initial_scale)) {
initial_scale <- proposal$default_initial_scale(initial_state$dimension())
}
log_scale <<- log(initial_scale)
proposal$update(scale = initial_scale)
}
update <- function(sample_index, state_and_statistics) {
update <- function(proposal, sample_index, state_and_statistics) {
if (is.null(target_accept_prob)) {
target_accept_prob <- proposal$default_target_accept_prob()
}
gamma <- sample_index^(-kappa)
accept_prob <- state_and_statistics$statistics$accept_prob
log_scale <<- log_scale + gamma * (accept_prob - target_accept_prob)
Expand All @@ -53,7 +51,7 @@ scale_adapter <- function(
list(
initialize = initialize,
update = update,
finalize = function() {},
finalize = NULL,
state = function() list(log_scale = log_scale)
)
}
Expand All @@ -74,15 +72,16 @@ scale_adapter <- function(
#' grad_log_density = function(x) -x
#' )
#' proposal <- barker_proposal(target_distribution)
#' adapter <- variance_adapter(proposal)
variance_adapter <- function(proposal, kappa = 0.6) {
#' adapter <- variance_adapter()
#' adapter$initialize(proposal, chain_state(c(0, 0)))
variance_adapter <- function(kappa = 0.6) {
mean_estimate <- NULL
variance_estimate <- NULL
initialize <- function(initial_state) {
initialize <- function(proposal, initial_state) {
mean_estimate <<- initial_state$position()
variance_estimate <<- rep(1., initial_state$dimension())
}
update <- function(sample_index, state_and_statistics) {
update <- function(proposal, sample_index, state_and_statistics) {
gamma <- sample_index^(-kappa)
position <- state_and_statistics$state$position()
mean_estimate <<- mean_estimate + gamma * (position - mean_estimate)
Expand Down Expand Up @@ -124,20 +123,23 @@ variance_adapter <- function(proposal, kappa = 0.6) {
#' grad_log_density = function(x) -x
#' )
#' proposal <- barker_proposal(target_distribution)
#' adapter <- robust_shape_adapter(
#' proposal,
#' initial_scale = 1.,
#' target_accept_prob = 0.4
#' )
#' adapter <- robust_shape_adapter(initial_scale = 1., target_accept_prob = 0.4)
#' adapter$initialize(proposal, chain_state(c(0, 0)))
robust_shape_adapter <- function(
proposal, initial_scale, target_accept_prob = 0.4, kappa = 0.6) {
initial_scale = NULL, target_accept_prob = NULL, kappa = 0.6) {
rlang::check_installed("ramcmc", reason = "to use this function")
shape <- NULL
initialize <- function(initial_state) {
initialize <- function(proposal, initial_state) {
if (is.null(initial_scale)) {
initial_scale <- proposal$default_initial_scale(initial_state$dimension())
}
shape <<- initial_scale * diag(initial_state$dimension())
proposal$update(shape = shape)
}
update <- function(sample_index, state_and_statistics) {
update <- function(proposal, sample_index, state_and_statistics) {
if (is.null(target_accept_prob)) {
target_accept_prob <- proposal$default_target_accept_prob()

Check warning on line 141 in R/adaptation.R

View check run for this annotation

Codecov / codecov/patch

R/adaptation.R#L141

Added line #L141 was not covered by tests
}
momentum <- state_and_statistics$proposed_state$momentum()
accept_prob <- state_and_statistics$statistics$accept_prob
shape <<- ramcmc::adapt_S(
Expand Down
12 changes: 6 additions & 6 deletions R/chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ chain_loop <- function(
statistic_names) {
progress_bar <- get_progress_bar(use_progress_bar, n_iteration, stage_name)
for (adapter in adapters) {
adapter$initialize(state)
adapter$initialize(proposal, state)
}
if (record_traces_and_statistics) {
trace_names <- names(unlist(trace_function(state)))
Expand All @@ -176,25 +176,25 @@ chain_loop <- function(
traces <- NULL
statistics <- NULL
}
for (s in seq_len(n_iteration)) {
for (chain_iteration in seq_len(n_iteration)) {
state_and_statistics <- sample_metropolis_hastings(
state, target_distribution, proposal
)
for (adapter in adapters) {
adapter$update(s + 1, state_and_statistics)
adapter$update(proposal, chain_iteration + 1, state_and_statistics)
}
state <- state_and_statistics$state
if (record_traces_and_statistics) {
traces[s, ] <- unlist(trace_function(state))
traces[chain_iteration, ] <- unlist(trace_function(state))
adapter_states <- lapply(adapters, function(a) a$state())
statistics[s, ] <- unlist(
statistics[chain_iteration, ] <- unlist(
c(state_and_statistics$statistics, adapter_states)
)
}
if (!is.null(progress_bar)) progress_bar$tick()
}
for (adapter in adapters) {
if (!is.null(adapter$finalize)) adapter$finalize()
if (!is.null(adapter$finalize)) adapter$finalize(proposal)

Check warning on line 197 in R/chains.R

View check run for this annotation

Codecov / codecov/patch

R/chains.R#L197

Added line #L197 was not covered by tests
}
list(final_state = state, traces = traces, statistics = statistics)
}
Expand Down
2 changes: 1 addition & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ results <- sample_chain(
initial_state = rnorm(dimension),
n_warm_up_iteration = 1000,
n_main_iteration = 1000,
adapters = list(scale_adapter(proposal), variance_adapter(proposal))
adapters = list(scale_adapter(), variance_adapter())
)
mean_accept_prob <- mean(results$statistics[, "accept_prob"])
adapted_shape <- proposal$parameters()$shape
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ results <- sample_chain(
initial_state = rnorm(dimension),
n_warm_up_iteration = 1000,
n_main_iteration = 1000,
adapters = list(scale_adapter(proposal), variance_adapter(proposal))
adapters = list(scale_adapter(), variance_adapter())
)
mean_accept_prob <- mean(results$statistics[, "accept_prob"])
adapted_shape <- proposal$parameters()$shape
Expand Down
19 changes: 6 additions & 13 deletions man/robust_shape_adapter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 8 additions & 15 deletions man/scale_adapter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 7 additions & 6 deletions man/variance_adapter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fa8bf5f

Please sign in to comment.