Skip to content

Commit

Permalink
fix issue #1111
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Mar 10, 2021
1 parent 7bbede1 commit ce57769
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 17 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ evaluation of most model terms.
Martin Modrak. The behavior can be controlled via `file_refit`. (#1058)
* Allow for a finer tuning of informational messages printed in `brm`
via the `silent` argument. (#1076)
* Allow user-defined Stan variables (specified via `stanvars`) to be used
inside threaded log-likelihoods. (#1111)

### Other Changes

Expand Down
3 changes: 2 additions & 1 deletion R/make_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ make_stancode <- function(formula, data, family = gaussian(),
pll_args <- stan_clean_pll_args(
scode_predictor[[i]]$pll_args,
scode_ranef$pll_args,
scode_Xme$pll_args
scode_Xme$pll_args,
collapse_stanvars_pll_args(stanvars)
)
partial_log_lik <- paste0(
scode_predictor[[i]]$pll_def,
Expand Down
25 changes: 13 additions & 12 deletions R/stan-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ stan_log_lik_skew_normal <- function(bterms, resp = "", mix = "",
sdist("skew_normal", p$mu, p$omega, p$alpha)
}

stan_log_lik_poisson <- function(bterms, resp = "", mix = "", threads = 1,
stan_log_lik_poisson <- function(bterms, resp = "", mix = "", threads = NULL,
...) {
if (use_glm_primitive(bterms)) {
p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads)
Expand All @@ -503,7 +503,7 @@ stan_log_lik_poisson <- function(bterms, resp = "", mix = "", threads = 1,
out
}

stan_log_lik_negbinomial <- function(bterms, resp = "", mix = "", threads = 1,
stan_log_lik_negbinomial <- function(bterms, resp = "", mix = "", threads = NULL,
...) {
if (use_glm_primitive(bterms)) {
p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads)
Expand All @@ -520,7 +520,7 @@ stan_log_lik_negbinomial <- function(bterms, resp = "", mix = "", threads = 1,
out
}

stan_log_lik_geometric <- function(bterms, resp = "", mix = "", threads = 1,
stan_log_lik_geometric <- function(bterms, resp = "", mix = "", threads = NULL,
...) {
if (use_glm_primitive(bterms)) {
p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads)
Expand All @@ -537,7 +537,7 @@ stan_log_lik_geometric <- function(bterms, resp = "", mix = "", threads = 1,
}
}

stan_log_lik_binomial <- function(bterms, resp = "", mix = "", threads = 1,
stan_log_lik_binomial <- function(bterms, resp = "", mix = "", threads = NULL,
...) {
reqn <- stan_log_lik_adj(bterms) || nzchar(mix)
p <- stan_log_lik_dpars(bterms, reqn, resp, mix)
Expand All @@ -547,7 +547,7 @@ stan_log_lik_binomial <- function(bterms, resp = "", mix = "", threads = 1,
sdist(lpdf, p$trials, p$mu)
}

stan_log_lik_bernoulli <- function(bterms, resp = "", mix = "", threads = 1,
stan_log_lik_bernoulli <- function(bterms, resp = "", mix = "", threads = NULL,
...) {
if (use_glm_primitive(bterms)) {
p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads)
Expand Down Expand Up @@ -619,7 +619,7 @@ stan_log_lik_inverse.gaussian <- function(bterms, resp = "", mix = "", ...) {
sdist(lpdf, p$mu, p$shape)
}

stan_log_lik_wiener <- function(bterms, resp = "", mix = "", threads = 1,
stan_log_lik_wiener <- function(bterms, resp = "", mix = "", threads = NULL,
...) {
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix)
n <- stan_nn(threads)
Expand All @@ -645,7 +645,7 @@ stan_log_lik_von_mises <- function(bterms, resp = "", mix = "", ...) {
sdist(lpdf, p$mu, p$kappa)
}

stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = 1,
stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = NULL,
...) {
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix)
n <- stan_nn(threads)
Expand All @@ -659,7 +659,7 @@ stan_log_lik_cox <- function(bterms, resp = "", mix = "", threads = 1,
}

stan_log_lik_cumulative <- function(bterms, resp = "", mix = "",
threads = 1, ...) {
threads = NULL, ...) {
if (use_glm_primitive(bterms, allow_special_terms = FALSE)) {
p <- args_glm_primitive(bterms$dpars$mu, resp = resp, threads = threads)
out <- sdist("ordered_logistic_glm", p$x, p$beta, p$alpha)
Expand Down Expand Up @@ -779,7 +779,7 @@ stan_log_lik_zero_inflated_negbinomial <- function(bterms, resp = "", mix = "",
}

stan_log_lik_zero_inflated_binomial <- function(bterms, resp = "", mix = "",
threads = 1, ...) {
threads = NULL, ...) {
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix)
n <- stan_nn(threads)
p$trials <- paste0("trials", resp, n)
Expand Down Expand Up @@ -810,8 +810,7 @@ stan_log_lik_zero_inflated_asym_laplace <- function(bterms, resp = "", mix = "",
sdist(lpdf, p$mu, p$sigma, p$quantile, p$zi)
}

stan_log_lik_custom <- function(bterms, resp = "", mix = "", ...) {
# TODO: support reduce_sum
stan_log_lik_custom <- function(bterms, resp = "", mix = "", threads = NULL, ...) {
p <- stan_log_lik_dpars(bterms, TRUE, resp, mix)
family <- bterms$family
dpars <- paste0(family$dpars, mix)
Expand All @@ -821,8 +820,10 @@ stan_log_lik_custom <- function(bterms, resp = "", mix = "", ...) {
}
# insert the response name into the 'vars' strings
# addition terms contain the response in their variable name
n <- stan_nn(threads)
var_names <- sub("\\[.+$", "", family$vars)
var_indices <- get_matches("\\[.+$", family$vars, first = TRUE)
var_indices <- ifelse(var_indices %in% "[n]", n, var_indices)
is_var_adterms <- var_names %in% c("se", "trials", "dec") |
grepl("^((vint)|(vreal))[[:digit:]]+$", var_names)
var_resps <- ifelse(is_var_adterms, resp, "")
Expand Down Expand Up @@ -866,7 +867,7 @@ use_glm_primitive <- function(bterms, allow_special_terms = TRUE) {
# @param bterms a btl object
# @param resp optional name of the response variable
# @return a named list of Stan code snippets
args_glm_primitive <- function(bterms, resp = "", threads = 1) {
args_glm_primitive <- function(bterms, resp = "", threads = NULL) {
stopifnot(is.btl(bterms))
decomp <- get_decomp(bterms$fe)
center_X <- stan_center_X(bterms)
Expand Down
59 changes: 57 additions & 2 deletions R/stanvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#' @param position Name of the position within the block where the
#' Stan code should be placed. Currently allowed are \code{'start'}
#' (the default) and \code{'end'} of the block.
#' @param pll_args Optional Stan code to be put into the header
#' of \code{partial_log_lik} functions. This ensures that the variables
#' specified in \code{scode} can be used in the likelihood even when
#' within-chain parallelization is activated via \code{\link{threading}}.
#'
#' @return An object of class \code{stanvars}.
#'
Expand All @@ -46,10 +50,18 @@
#' block = "parameters")
#' make_stancode(count ~ Trt + zBase, epilepsy,
#' prior = bprior, stanvars = stanvars)
#'
#' # ensure that 'tau' is passed to the likelihood of a threaded model
#' # not necessary for this example but may be necessary in other cases
#' stanvars <- stanvar(scode = "real<lower=0> tau;",
#' block = "parameters", pll_args = "real tau")
#' make_stancode(count ~ Trt + zBase, epilepsy,
#' stanvars = stanvars, threads = threading(2))
#'
#' @export
stanvar <- function(x = NULL, name = NULL, scode = NULL,
block = "data", position = "start") {
block = "data", position = "start",
pll_args = NULL) {
vblocks <- c(
"data", "tdata", "parameters", "tparameters",
"model", "genquant", "functions"
Expand Down Expand Up @@ -99,6 +111,34 @@ stanvar <- function(x = NULL, name = NULL, scode = NULL,
}
scode <- paste0(" ", scode, ";")
}
if (is.null(pll_args)) {
# infer pll_args from x
if (is.integer(x)) {
if (length(x) == 1L) {
pll_type <- "int"
} else {
pll_type <- "int[]"
}
} else if (is.vector(x)) {
if (length(x) == 1L) {
pll_type <- "real"
} else {
pll_type <- "vector"
}
} else if (is.array(x)) {
if (length(dim(x)) == 1L) {
pll_type <- "vector"
} else if (is.matrix(x)) {
pll_type <- "matrix"
}
}
if (!is.null(pll_type)) {
pll_args <- paste0(pll_type, " ", name)
} else {
# don't throw an error because most people will not use threading
pll_args <- character(0)
}
}
} else {
x <- NULL
if (is.null(name)) {
Expand All @@ -108,11 +148,13 @@ stanvar <- function(x = NULL, name = NULL, scode = NULL,
if (is.null(scode)) {
stop2("Argument 'scode' is required if block is not 'data'.")
}
scode <- as.character(scode)
pll_args <- as.character(pll_args)
}
if (position == "end" && block %in% c("functions", "data", "model")) {
stop2("Position '", position, "' is not sensible for block '", block, "'.")
}
out <- nlist(name, sdata = x, scode, block, position)
out <- nlist(name, sdata = x, scode, block, position, pll_args)
structure(setNames(list(out), name), class = "stanvars")
}

Expand Down Expand Up @@ -142,6 +184,19 @@ collapse_stanvars <- function(x, block = NULL, position = NULL) {
collapse(ulapply(x, "[[", "scode"), "\n")
}

# collapse partial lpg-lik code provided in a stanvars object
collapse_stanvars_pll_args <- function(x) {
x <- validate_stanvars(x)
if (!length(x)) {
return(character(0))
}
out <- ulapply(x, "[[", "pll_args")
if (!length(out)) {
return("")
}
collapse(", ", out)
}

# validate 'stanvars' objects
validate_stanvars <- function(x, stan_funs = NULL) {
if (is.null(x)) {
Expand Down
15 changes: 14 additions & 1 deletion man/stanvar.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 17 additions & 1 deletion tests/testthat/tests.make_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -1938,7 +1938,7 @@ test_that("argument 'stanvars' is handled correctly", {
expect_match2(scode, "vector[K] M;")
expect_match2(scode, "matrix[K, K] V;")

# define a hierachical prior on the regression coefficients
# define a hierarchical prior on the regression coefficients
bprior <- set_prior("normal(0, tau)", class = "b") +
set_prior("target += normal_lpdf(tau | 0, 10)", check = FALSE)
stanvars <- stanvar(scode = "real<lower=0> tau;",
Expand All @@ -1948,6 +1948,22 @@ test_that("argument 'stanvars' is handled correctly", {
expect_match2(scode, "real<lower=0> tau;")
expect_match2(scode, "target += normal_lpdf(b | 0, tau);")

# ensure that variables are passed to the likelihood of a threaded model
foo <- 0.5
stanvars <- stanvar(foo) +
stanvar(scode = "real<lower=0> tau;",
block = "parameters", pll_args = "real tau")

scode <- make_stancode(count ~ 1, epilepsy, family = poisson(),
stanvars = stanvars, threads = threading(2),
parse = FALSE)
expect_match2(scode,
"partial_log_lik_lpmf(int[] seq, int start, int end, int[] Y, real Intercept, real foo, real tau)"
)
expect_match2(scode,
"reduce_sum(partial_log_lik_lpmf, seq, grainsize, Y, Intercept, foo, tau)"
)

# add transformation at the end of a block
stanvars <- stanvar(scode = " r_1_1 = r_1_1 * 2;",
block = "tparameters", position = "end")
Expand Down

0 comments on commit ce57769

Please sign in to comment.