Skip to content

Commit

Permalink
Merge pull request #139 from mrc-ide/mrc-6212
Browse files Browse the repository at this point in the history
Allow domain to be given for DSL models
  • Loading branch information
richfitz authored Feb 11, 2025
2 parents c761d69 + 8b13833 commit f736d2a
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 22 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.3.27
Version: 0.3.28
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
28 changes: 27 additions & 1 deletion R/dsl-generate.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,26 @@ dsl_generate <- function(dat) {

dsl_generate_density <- function(dat, env, meta) {
exprs <- lapply(dat$exprs, dsl_generate_density_expr, meta)
body <- c(call("<-", meta[["pars"]], quote(packer$unpack(x))),
body_exprs <- c(call("<-", meta[["pars"]], quote(packer$unpack(x))),
call("<-", meta[["density"]], quote(numeric())),
exprs,
call("sum", meta[["density"]]))
if (is.null(dat$domain)) {
body <- rlang::call2("{", !!!body_exprs)
} else {
if (nrow(dat$domain) == 1) {
domain_min <- dat$domain[[1]]
domain_max <- dat$domain[[2]]
in_domain <- bquote(x >= .(domain_min) && x <= .(domain_max))
} else {
domain_min <- rlang::call2("c", !!!unname(dat$domain[, 1]))
domain_max <- rlang::call2("c", !!!unname(dat$domain[, 2]))
in_domain <- bquote(all(x >= .(domain_min) & x <= .(domain_max)))
}
body <- call(
"{",
call("if", in_domain, rlang::call2("{", !!!body_exprs), call("{", -Inf)))
}
vectorise_density_over_parameters(
as_function(alist(x = ), body, env))
}
Expand Down Expand Up @@ -58,6 +74,9 @@ dsl_generate_gradient <- function(dat, env, meta) {


dsl_generate_direct_sample <- function(dat, env, meta) {
if (!is.null(dat$domain)) {
return(NULL)
}
exprs <- lapply(dat$exprs, dsl_generate_sample_expr, meta)
body <- c(call("<-", meta[["pars"]], quote(list())),
exprs,
Expand Down Expand Up @@ -146,6 +165,13 @@ dsl_generate_domain <- function(dat, meta) {
}
}
}

## Same logic as model_combine_domain
if (!is.null(dat$domain)) {
domain[, 1] <- pmax(domain[, 1], dat$domain[, 1])
domain[, 2] <- pmin(domain[, 2], dat$domain[, 2])
}

domain
}

Expand Down
9 changes: 7 additions & 2 deletions R/dsl-parse.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## The default of gradient_required = TRUE here helps with tests
dsl_parse <- function(exprs, gradient_required = TRUE, fixed = NULL,
call = NULL) {
domain = NULL, call = NULL) {
exprs <- lapply(exprs, dsl_parse_expr, call)

dsl_parse_check_duplicates(exprs, call)
Expand All @@ -10,9 +10,14 @@ dsl_parse <- function(exprs, gradient_required = TRUE, fixed = NULL,
name <- vcapply(exprs, "[[", "name")
parameters <- name[vcapply(exprs, "[[", "type") == "stochastic"]

if (!is.null(domain)) {
domain <- validate_domain(domain, parameters, call = call)
}

adjoint <- dsl_parse_adjoint(parameters, exprs, gradient_required)

list(parameters = parameters, exprs = exprs, adjoint = adjoint, fixed = fixed)
list(parameters = parameters, exprs = exprs, adjoint = adjoint,
fixed = fixed, domain = domain)
}


Expand Down
20 changes: 16 additions & 4 deletions R/dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@
##' hyperparameters that are fixed across a set of model runs, for
##' example.
##'
##' @param domain An optional domain. Normally this is not wanted,
##' but you can use this to truncate the domain of one or more
##' parameters. The domain is effectively applied *after* the
##' calculations implied by the DSL. The density is not
##' recalculated to reflect the change in the marginal density.
##' Applying a domain will remove the ability to sample from the
##' model, at least for now. See [monty_model] for details on the
##' format. The provided parameters must match the parameters of
##' your model.
##'
##' @return A [monty_model] object derived from the expressions you
##' provide.
##'
Expand All @@ -48,7 +58,8 @@
##'
##' # You can also pass strings
##' monty_dsl("a ~ Normal(0, 1)")
monty_dsl <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
monty_dsl <- function(x, type = NULL, gradient = NULL, fixed = NULL,
domain = NULL) {
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
x <- rlang::eval_tidy(quo)
Expand All @@ -58,13 +69,14 @@ monty_dsl <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
call <- environment()
fixed <- check_dsl_fixed(fixed)
exprs <- dsl_preprocess(x, type, call)
dat <- dsl_parse(exprs, gradient, fixed, call)
dat <- dsl_parse(exprs, gradient, fixed, domain, call)
dsl_generate(dat)
}



monty_dsl_parse <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
monty_dsl_parse <- function(x, type = NULL, gradient = NULL, fixed = NULL,
domain = NULL) {
call <- environment()
quo <- rlang::enquo(x)
if (rlang::quo_is_symbol(quo)) {
Expand All @@ -74,7 +86,7 @@ monty_dsl_parse <- function(x, type = NULL, gradient = NULL, fixed = NULL) {
}
fixed <- check_dsl_fixed(fixed, call)
exprs <- dsl_preprocess(x, type, call)
dsl_parse(exprs, gradient, fixed, call)
dsl_parse(exprs, gradient, fixed, domain, call)
}


Expand Down
26 changes: 24 additions & 2 deletions R/model-function.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
##' conjunction with `packer` (you should use the `fixed` argument
##' to `monty_packer` instead).
##'
##' @param domain Optional domain, see [monty_model]'s arguments for
##' details.
##'
##' @param allow_multiple_parameters Logical, indicating if passing in
##' vectors for all parameters will return a vector of densities.
##' This is `FALSE` by default because we cannot determine this
Expand All @@ -48,6 +51,7 @@
##' # Same as the built-in banana example:
##' monty_model_density(monty_example("banana"), c(0, 0))
monty_model_function <- function(density, packer = NULL, fixed = NULL,
domain = NULL,
allow_multiple_parameters = FALSE) {
if (!is.function(density)) {
cli::cli_abort("Expected 'density' to be a function", arg = "density")
Expand Down Expand Up @@ -84,10 +88,28 @@ monty_model_function <- function(density, packer = NULL, fixed = NULL,
properties <- monty_model_properties(
allow_multiple_parameters = allow_multiple_parameters)

parameters <- packer$names()

use_domain <- !is.null(domain)
if (use_domain) {
domain <- validate_domain(domain, parameters, call = call)
if (allow_multiple_parameters) {
## This involves some pretty tedious bookkeeping, and is going
## to interact with the interface for running an indexed subset
## of parameters that we need to sort dust out properly.
cli::cli_abort(
"'allow_multiple_parameters' and 'domain' cannot yet be used together")
}
}

monty_model(
list(parameters = packer$names(),
list(parameters = parameters,
density = function(x) {
if (use_domain && !all(x >= domain[, 1] & x <= domain[, 2])) {
return(-Inf)
}
rlang::inject(density(!!!packer$unpack(x), !!!fixed))
}),
},
domain = domain),
properties)
}
27 changes: 16 additions & 11 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -358,22 +358,27 @@ validate_model_parameters <- function(model, call = NULL) {


validate_model_domain <- function(model, call = NULL) {
domain <- model$domain
n_pars <- length(model$parameters)
validate_domain(model$domain, model$parameters, call = call)
}


validate_domain <- function(domain, parameters,
name = deparse(substitute(domain)), call = NULL) {
n_pars <- length(parameters)

if (is.null(domain)) {
domain <- cbind(rep(-Inf, n_pars), rep(Inf, n_pars))
rownames(domain) <- model$parameters
rownames(domain) <- parameters
return(domain)
}

if (!is.matrix(domain)) {
cli::cli_abort("Expected 'model$domain' to be a matrix if non-NULL",
cli::cli_abort("Expected '{name}' to be a matrix if non-NULL",
call = call)
}
if (ncol(domain) != 2) {
cli::cli_abort(
c(paste("Expected 'model$domain' to have 2 columns,",
c(paste("Expected '{name}' to have 2 columns,",
"but it had {ncol(domain)}"),
i = paste("Because your domain is unnamed, if given it must",
"include all parameters in the same order as your model")),
Expand All @@ -384,30 +389,30 @@ validate_model_domain <- function(model, call = NULL) {
if (is.null(nms)) {
if (nrow(domain) != n_pars) {
cli::cli_abort(
paste("Expected 'model$domain' to have {n_pars} row{?s},",
paste("Expected '{name}' to have {n_pars} row{?s},",
"but it had {nrow(domain)}"),
call = call)
}
rownames(domain) <- model$parameters
rownames(domain) <- parameters
} else {
## We might treat parameters that begin with '[' specially and
## allow these to replicate. So if the user has a[1], a[2],
## a[3] then a row with 'a' will apply across all of these that
## are not explicitly given.
err <- setdiff(nms, model$parameters)
err <- setdiff(nms, parameters)
if (length(err) > 0) {
cli::cli_abort(
c("Unexpected parameters found in 'model$domain' rownames",
c("Unexpected parameters found in '{name}' rownames",
set_names(err, "x")),
call = call)
}
msg <- setdiff(model$parameters, nms)
msg <- setdiff(parameters, nms)
if (length(msg) > 0) {
extra <- cbind(rep(-Inf, length(msg)), rep(Inf, length(msg)))
rownames(extra) <- msg
domain <- rbind(domain, extra)
}
domain <- domain[model$parameters, , drop = FALSE]
domain <- domain[parameters, , drop = FALSE]
}

domain
Expand Down
12 changes: 11 additions & 1 deletion man/monty_dsl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/monty_model_function.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions tests/testthat/test-dsl.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,27 @@ test_that("gradient calculation correct single-parameter model", {
m$gradient(x[, 1, drop = FALSE]),
cbind(m$gradient(x[, 1])))
})


test_that("can apply a domain", {
m <- monty_dsl({
a ~ Exponential(1)
},
domain = rbind(a = c(1, 4)))
expect_equal(m$density(2), -2)
expect_equal(m$density(0), -Inf)
expect_equal(m$density(6), -Inf)
})


test_that("can apply a domain", {
m <- monty_dsl({
a ~ Exponential(1)
b ~ Normal(0, 1)
},
domain = rbind(a = c(1, 4), b = c(-2, 2)))
expect_equal(m$density(c(2, 1)),
dexp(2, log = TRUE) + dnorm(1, log = TRUE))
expect_equal(m$density(c(0, 1)), -Inf)
expect_equal(m$density(c(1, 6)), -Inf)
})
25 changes: 25 additions & 0 deletions tests/testthat/test-model-function.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,28 @@ test_that("can't use process in packer with multiple parameters", {
monty_model_function(fn, packer = p, allow_multiple_parameters = TRUE),
"Can't use 'allow_multiple_parameters' with a packer")
})


test_that("can apply domain to model from function", {
fn <- function(a, b) {
dnorm(0, a, b, log = TRUE)
}
m <- monty_model_function(fn, domain = rbind(b = c(1, 5)))
expect_s3_class(m, "monty_model")
expect_equal(m$domain, rbind(a = c(-Inf, Inf), b = c(1, 5)))
expect_equal(m$parameters, c("a", "b"))
expect_equal(monty_model_density(m, c(1, 2)),
dnorm(0, 1, 2, log = TRUE))
expect_equal(monty_model_density(m, c(1, 6)), -Inf)
})


test_that("cannot use domain and multiple parameters", {
fn <- function(a, b) {
dnorm(0, a, b, log = TRUE)
}
domain <- rbind(b = c(1, 5))
expect_error(
monty_model_function(fn, domain = domain, allow_multiple_parameters = TRUE),
"'allow_multiple_parameters' and 'domain' cannot yet be used together")
})

0 comments on commit f736d2a

Please sign in to comment.