-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
350 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export(StanModel) | ||
export(compile_model) | ||
export(set_bridgestan_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
IS_WINDOWS <- isTRUE(.Platform$OS.type == "windows") | ||
MAKE <- Sys.getenv("MAKE", ifelse(IS_WINDOWS, "mingw32-make", "make")) | ||
|
||
|
||
#' Get the path to BridgeStan. | ||
#' | ||
#' By default this is set to the value of the environment | ||
#' variable `BRIDGESTAN`. | ||
#' | ||
#' If there is no path set, this function will download | ||
#' a matching version of BridgeStan to a folder called | ||
#' `.bridgestan` in the user's home directory. | ||
#' | ||
#' See also `set_bridgestan_path` | ||
verify_bridgestan_path <- function(path) { | ||
suppressWarnings({ | ||
folder <- normalizePath(path) | ||
}) | ||
if (!dir.exists(folder)) { | ||
stop(paste0("BridgeStan folder '", folder, "' does not exist!\n", "If you need to set a different location, call 'set_bridgestan_path()'")) | ||
} | ||
makefile <- file.path(folder, "Makefile") | ||
if (!file.exists(makefile)) { | ||
stop(paste0("BridgeStan folder '", folder, "' does not contain file 'Makefile',", | ||
" please ensure it is built properly!\n", "If you need to set a different location, call 'set_bridgestan_path()'")) | ||
} | ||
} | ||
|
||
#' Set the path to BridgeStan. | ||
#' | ||
#' This should point to the top-level folder of the repository. | ||
#' @export | ||
set_bridgestan_path <- function(path) { | ||
verify_bridgestan_path(path) | ||
Sys.setenv(BRIDGESTAN = normalizePath(path)) | ||
} | ||
|
||
get_bridgestan_path <- function() { | ||
# try to get from environment | ||
path <- Sys.getenv("BRIDGESTAN") | ||
if (path == "") { | ||
path <- CURRENT_BRIDGESTAN | ||
tryCatch({ | ||
verify_bridgestan_path(path) | ||
}, error = function(e) { | ||
print(paste0("Bridgestan not found at location specified by $BRIDGESTAN ", | ||
"environment variable, downloading version ", packageVersion("bridgestan"), | ||
" to ", path)) | ||
get_bridgestan_src() | ||
}) | ||
} | ||
|
||
return(path) | ||
} | ||
|
||
|
||
#' Run BridgeStan's Makefile on a `.stan` file, creating the `.so` | ||
#' used by the StanModel class. | ||
#' This function assumes that the path to BridgeStan is valid. | ||
#' This can be set with `set_bridgestan_path`. | ||
#' | ||
#' @param stan_file A path to a Stan model file. | ||
#' @param stanc_arg A list of arguments to pass to stanc3. | ||
#' For example, `c('--O1')` will enable compiler optimization level 1. | ||
#' @param make_args A list of additional arguments to pass to Make. | ||
#' For example, `c('STAN_THREADS=True')` will enable | ||
#' threading for the compiled model. If the same flags are defined | ||
#' in `make/local`, the versions passed here will take precedent. | ||
#' @return Path to the compiled model. | ||
#' @export | ||
compile_model <- function(stan_file, stanc_args = NULL, make_args = NULL) { | ||
verify_bridgestan_path(get_bridgestan_path()) | ||
suppressWarnings({ | ||
file_path <- normalizePath(stan_file) | ||
}) | ||
if (tools::file_ext(file_path) != "stan") { | ||
stop(paste0("File '", file_path, "' does not end with '.stan'")) | ||
} | ||
if (!file.exists(file_path)) { | ||
stop(paste0("File '", file_path, "' does not exist!")) | ||
} | ||
|
||
output <- paste0(tools::file_path_sans_ext(file_path), "_model.so") | ||
stancflags <- paste("--include-paths=.", paste(stanc_args, collapse = " ")) | ||
|
||
flags <- c(paste("-C", get_bridgestan_path()), make_args, paste0("STANCFLAGS=\"", | ||
stancflags, "\""), output) | ||
|
||
suppressWarnings({ | ||
res <- system2(MAKE, args = flags, stdout = TRUE, stderr = TRUE) | ||
}) | ||
res_attrs <- attributes(res) | ||
if ("status" %in% names(res_attrs) && res_attrs$status != 0) { | ||
stop(paste0("Compilation failed with error code ", res_attrs$status, "\noutput:\n", | ||
paste(res, collapse = "\n"))) | ||
} | ||
|
||
return(output) | ||
} | ||
|
||
windows_path_setup <- function() { | ||
if (.Platform$OS.type == "windows") { | ||
suppressWarnings(out <- system2("where.exe", "tbb.dll", stdout = NULL, stderr = NULL)) | ||
if (out != 0) { | ||
tbb_path <- file.path(get_bridgestan_path(), "stan", "lib", "stan_math", | ||
"lib", "tbb") | ||
Sys.setenv(PATH = paste(tbb_path, Sys.getenv("PATH"), sep = ";")) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
current_version <- packageVersion("bridgestan") | ||
HOME_BRIDGESTAN <- path.expand(file.path("~", ".bridgestan")) | ||
CURRENT_BRIDGESTAN <- file.path(HOME_BRIDGESTAN, paste0("bridgestan-", current_version)) | ||
|
||
RETRIES <- 5 | ||
|
||
get_bridgestan_src <- function() { | ||
url <- paste0("https://github.com/roualdes/bridgestan/releases/download/", "v", | ||
current_version, "/bridgestan-", current_version, ".tar.gz") | ||
|
||
dir.create(HOME_BRIDGESTAN, showWarnings = FALSE, recursive = TRUE) | ||
temp <- tempfile() | ||
err_text <- paste("Failed to download Bridgestan", current_version, "from github.com.") | ||
for (i in 1:RETRIES) { | ||
tryCatch({ | ||
download.file(url, destfile = temp, mode = "wb", quiet = TRUE, method = "auto") | ||
}, error = function(e) { | ||
cat(err_text, "\n") | ||
if (i == RETRIES) { | ||
stop(err_text, call. = FALSE) | ||
} else { | ||
cat("Retrying (", i + 1, "/", RETRIES, ")...\n", sep = "") | ||
Sys.sleep(1) | ||
} | ||
}) | ||
} | ||
|
||
tryCatch({ | ||
untar(temp, exdir = HOME_BRIDGESTAN) | ||
}, error = function(e) { | ||
stop(paste("Failed to unpack", url, "during installation"), call. = FALSE) | ||
}) | ||
|
||
unlink(temp) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
base = "../../.." | ||
|
||
load_model <- function(name, include_data = TRUE) { | ||
if (include_data) { | ||
data = file.path(base, "test_models", name, paste0(name, ".data.json")) | ||
} else { | ||
data = "" | ||
} | ||
model <- StanModel$new(file.path(base, "test_models", name, paste0(name, "_model.so")), | ||
data, 1234) | ||
return(model) | ||
} | ||
|
||
simple <- load_model("simple") | ||
bernoulli <- load_model("bernoulli") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
|
||
test_that("loading another library didn't break prior ones", { | ||
if (.Platform$OS.type == "windows") { | ||
dll = "./test_collisions.dll" | ||
} else { | ||
dll = "./test_collisions.so" | ||
} | ||
if (file.exists(dll)) { | ||
dyn.load(dll) | ||
expect_equal(bernoulli$name(), "bernoulli_model") | ||
expect_equal(simple$name(), "simple_model") | ||
} | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
|
||
|
||
test_that("compilation works", { | ||
name <- "multi" | ||
file <- file.path(base, "test_models", name, paste0(name, ".stan")) | ||
|
||
lib <- file.path(base, "test_models", name, paste0(name, "_model.so")) | ||
unlink(lib, force = TRUE) | ||
|
||
out <- compile_model(file, stanc_args = c("--O1")) | ||
|
||
expect_true(file.exists(lib)) | ||
expect_equal(normalizePath(lib), normalizePath(out)) | ||
|
||
unlink(lib, force = TRUE) | ||
|
||
out <- compile_model(file, make_args = c("STAN_THREADS=True")) | ||
}) | ||
|
||
test_that("compilation fails on non-stan file", { | ||
expect_error(compile_model(file.path(base, "test_models", "simple", "simple.data.json")), | ||
"does not end with '.stan'") | ||
}) | ||
|
||
test_that("compilation fails on missing file", { | ||
expect_error(compile_model("badpath.stan"), "does not exist!") | ||
}) | ||
|
||
test_that("compilation fails on bad syntax", { | ||
expect_error(compile_model(file.path(base, "test_models", "syntax_error", "syntax_error.stan")), | ||
"Compilation failed") | ||
}) | ||
|
||
test_that("bad paths fail", { | ||
expect_error(set_bridgestan_path("badpath"), "does not exist!") | ||
expect_error(set_bridgestan_path(file.path(base, "test_models")), "does not contain file 'Makefile'") | ||
}) |
Oops, something went wrong.