Skip to content

Commit

Permalink
Support compilation inside R code
Browse files Browse the repository at this point in the history
  • Loading branch information
WardBrian committed Nov 8, 2023
1 parent 9fecdb3 commit b9fefc4
Show file tree
Hide file tree
Showing 17 changed files with 350 additions and 78 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,6 @@ jobs:
path: ./test_models/
key: ${{ hashFiles('**/*.stan', 'src/*', 'stan/src/stan/version.hpp', 'Makefile') }}-${{ matrix.os }}-v${{ env.CACHE_VERSION }}

# needed for R tests until they have compilation utilities and can set this themselves.
- name: Set up TBB
if: matrix.os == 'windows-latest'
run: |
Add-Content $env:GITHUB_PATH "$(pwd)/stan/lib/stan_math/lib/tbb"
- name: Run tests
if: matrix.os != 'windows-latest'
run: |
Expand All @@ -231,6 +225,8 @@ jobs:
Rscript -e "devtools::test(reporter = c(\"summary\", \"fail\"))"
Rscript -e "install.packages(getwd(), repos=NULL, type=\"source\")"
Rscript example.R
env:
BRIDGESTAN: ${{ github.workspace }}

- name: Run tests (windows)
if: matrix.os == 'windows-latest'
Expand All @@ -241,6 +237,8 @@ jobs:
Rscript -e 'devtools::test(reporter = c(\"summary\", \"fail\"))'
Rscript -e 'install.packages(getwd(), repos=NULL, type=\"source\")'
Rscript example.R
env:
BRIDGESTAN: ${{ github.workspace }}

rust:
needs: [build]
Expand Down
2 changes: 2 additions & 0 deletions R/NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Generated by roxygen2: do not edit by hand

export(StanModel)
export(compile_model)
export(set_bridgestan_path)
6 changes: 4 additions & 2 deletions R/R/bridgestan.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ StanModel <- R6::R6Class("StanModel",
#' @return A new StanModel.
initialize = function(lib, data, seed) {
if (.Platform$OS.type == "windows"){
windows_path_setup()
lib_old <- lib
lib <- paste0(tools::file_path_sans_ext(lib), ".dll")
file.copy(from=lib_old, to=lib)
Expand Down Expand Up @@ -75,7 +76,8 @@ StanModel <- R6::R6Class("StanModel",
PACKAGE = private$lib_name
)$info_out
},

#' @description
#' Get the version of BridgeStan used in the compiled model.
model_version= function() {
.C("bs_version_R",
major = as.integer(0),
Expand Down Expand Up @@ -345,7 +347,7 @@ handle_error <- function(lib_name, err_msg, err_ptr, function_name) {
#' StanRNG
#'
#' RNG object for use with `StanModel$param_constrain()`
#' @field rng The pointer to the RNG object.
#' @field ptr The pointer to the RNG object.
#' @keywords internal
StanRNG <- R6::R6Class("StanRNG",
public = list(
Expand Down
110 changes: 110 additions & 0 deletions R/R/compile.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
IS_WINDOWS <- isTRUE(.Platform$OS.type == "windows")
MAKE <- Sys.getenv("MAKE", ifelse(IS_WINDOWS, "mingw32-make", "make"))


#' Get the path to BridgeStan.
#'
#' By default this is set to the value of the environment
#' variable `BRIDGESTAN`.
#'
#' If there is no path set, this function will download
#' a matching version of BridgeStan to a folder called
#' `.bridgestan` in the user's home directory.
#'
#' See also `set_bridgestan_path`
verify_bridgestan_path <- function(path) {
suppressWarnings({
folder <- normalizePath(path)
})
if (!dir.exists(folder)) {
stop(paste0("BridgeStan folder '", folder, "' does not exist!\n", "If you need to set a different location, call 'set_bridgestan_path()'"))
}
makefile <- file.path(folder, "Makefile")
if (!file.exists(makefile)) {
stop(paste0("BridgeStan folder '", folder, "' does not contain file 'Makefile',",
" please ensure it is built properly!\n", "If you need to set a different location, call 'set_bridgestan_path()'"))
}
}

#' Set the path to BridgeStan.
#'
#' This should point to the top-level folder of the repository.
#' @export
set_bridgestan_path <- function(path) {
verify_bridgestan_path(path)
Sys.setenv(BRIDGESTAN = normalizePath(path))
}

get_bridgestan_path <- function() {
# try to get from environment
path <- Sys.getenv("BRIDGESTAN")
if (path == "") {
path <- CURRENT_BRIDGESTAN
tryCatch({
verify_bridgestan_path(path)
}, error = function(e) {
print(paste0("Bridgestan not found at location specified by $BRIDGESTAN ",
"environment variable, downloading version ", packageVersion("bridgestan"),
" to ", path))
get_bridgestan_src()
})
}

return(path)
}


#' Run BridgeStan's Makefile on a `.stan` file, creating the `.so`
#' used by the StanModel class.
#' This function assumes that the path to BridgeStan is valid.
#' This can be set with `set_bridgestan_path`.
#'
#' @param stan_file A path to a Stan model file.
#' @param stanc_arg A list of arguments to pass to stanc3.
#' For example, `c('--O1')` will enable compiler optimization level 1.
#' @param make_args A list of additional arguments to pass to Make.
#' For example, `c('STAN_THREADS=True')` will enable
#' threading for the compiled model. If the same flags are defined
#' in `make/local`, the versions passed here will take precedent.
#' @return Path to the compiled model.
#' @export
compile_model <- function(stan_file, stanc_args = NULL, make_args = NULL) {
verify_bridgestan_path(get_bridgestan_path())
suppressWarnings({
file_path <- normalizePath(stan_file)
})
if (tools::file_ext(file_path) != "stan") {
stop(paste0("File '", file_path, "' does not end with '.stan'"))
}
if (!file.exists(file_path)) {
stop(paste0("File '", file_path, "' does not exist!"))
}

output <- paste0(tools::file_path_sans_ext(file_path), "_model.so")
stancflags <- paste("--include-paths=.", paste(stanc_args, collapse = " "))

flags <- c(paste("-C", get_bridgestan_path()), make_args, paste0("STANCFLAGS=\"",
stancflags, "\""), output)

suppressWarnings({
res <- system2(MAKE, args = flags, stdout = TRUE, stderr = TRUE)
})
res_attrs <- attributes(res)
if ("status" %in% names(res_attrs) && res_attrs$status != 0) {
stop(paste0("Compilation failed with error code ", res_attrs$status, "\noutput:\n",
paste(res, collapse = "\n")))
}

return(output)
}

windows_path_setup <- function() {
if (.Platform$OS.type == "windows") {
suppressWarnings(out <- system2("where.exe", "tbb.dll", stdout = NULL, stderr = NULL))
if (out != 0) {
tbb_path <- file.path(get_bridgestan_path(), "stan", "lib", "stan_math",
"lib", "tbb")
Sys.setenv(PATH = paste(tbb_path, Sys.getenv("PATH"), sep = ";"))
}
}
}
35 changes: 35 additions & 0 deletions R/R/download.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
current_version <- packageVersion("bridgestan")
HOME_BRIDGESTAN <- path.expand(file.path("~", ".bridgestan"))
CURRENT_BRIDGESTAN <- file.path(HOME_BRIDGESTAN, paste0("bridgestan-", current_version))

RETRIES <- 5

get_bridgestan_src <- function() {
url <- paste0("https://github.com/roualdes/bridgestan/releases/download/", "v",
current_version, "/bridgestan-", current_version, ".tar.gz")

dir.create(HOME_BRIDGESTAN, showWarnings = FALSE, recursive = TRUE)
temp <- tempfile()
err_text <- paste("Failed to download Bridgestan", current_version, "from github.com.")
for (i in 1:RETRIES) {
tryCatch({
download.file(url, destfile = temp, mode = "wb", quiet = TRUE, method = "auto")
}, error = function(e) {
cat(err_text, "\n")
if (i == RETRIES) {
stop(err_text, call. = FALSE)
} else {
cat("Retrying (", i + 1, "/", RETRIES, ")...\n", sep = "")
Sys.sleep(1)
}
})
}

tryCatch({
untar(temp, exdir = HOME_BRIDGESTAN)
}, error = function(e) {
stop(paste("Failed to unpack", url, "during installation"), call. = FALSE)
})

unlink(temp)
}
1 change: 1 addition & 0 deletions R/man/StanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions R/man/StanRNG.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 31 additions & 0 deletions R/man/compile_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions R/man/set_bridgestan_path.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions R/man/verify_bridgestan_path.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 0 additions & 8 deletions R/tests/testthat.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
# This file is part of the standard setup for testthat.
# It is recommended that you do not modify it.
#
# Where should you do additional test configuration?
# Learn more about the roles of various files in:
# * https://r-pkgs.org/tests.html
# * https://testthat.r-lib.org/reference/test_package.html#special-files

library(testthat)
library(bridgestan)

Expand Down
15 changes: 15 additions & 0 deletions R/tests/testthat/setup.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
base = "../../.."

load_model <- function(name, include_data = TRUE) {
if (include_data) {
data = file.path(base, "test_models", name, paste0(name, ".data.json"))
} else {
data = ""
}
model <- StanModel$new(file.path(base, "test_models", name, paste0(name, "_model.so")),
data, 1234)
return(model)
}

simple <- load_model("simple")
bernoulli <- load_model("bernoulli")
13 changes: 13 additions & 0 deletions R/tests/testthat/test_collisions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

test_that("loading another library didn't break prior ones", {
if (.Platform$OS.type == "windows") {
dll = "./test_collisions.dll"
} else {
dll = "./test_collisions.so"
}
if (file.exists(dll)) {
dyn.load(dll)
expect_equal(bernoulli$name(), "bernoulli_model")
expect_equal(simple$name(), "simple_model")
}
})
37 changes: 37 additions & 0 deletions R/tests/testthat/test_compile.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@


test_that("compilation works", {
name <- "multi"
file <- file.path(base, "test_models", name, paste0(name, ".stan"))

lib <- file.path(base, "test_models", name, paste0(name, "_model.so"))
unlink(lib, force = TRUE)

out <- compile_model(file, stanc_args = c("--O1"))

expect_true(file.exists(lib))
expect_equal(normalizePath(lib), normalizePath(out))

unlink(lib, force = TRUE)

out <- compile_model(file, make_args = c("STAN_THREADS=True"))
})

test_that("compilation fails on non-stan file", {
expect_error(compile_model(file.path(base, "test_models", "simple", "simple.data.json")),
"does not end with '.stan'")
})

test_that("compilation fails on missing file", {
expect_error(compile_model("badpath.stan"), "does not exist!")
})

test_that("compilation fails on bad syntax", {
expect_error(compile_model(file.path(base, "test_models", "syntax_error", "syntax_error.stan")),
"Compilation failed")
})

test_that("bad paths fail", {
expect_error(set_bridgestan_path("badpath"), "does not exist!")
expect_error(set_bridgestan_path(file.path(base, "test_models")), "does not contain file 'Makefile'")
})
Loading

0 comments on commit b9fefc4

Please sign in to comment.