Skip to content

Commit

Permalink
Merge pull request #1717 from mcol/issue_1713
Browse files Browse the repository at this point in the history
Add the brms_seed to loo_R2
  • Loading branch information
paul-buerkner authored Dec 17, 2024
2 parents a0602ad + cfd471d commit 9d1acf2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

* Fit extended-support Beta models via family `xbeta`
thanks to Ioannis Kosmidis. (#1698)
* Add a `seed` argument to `loo_R2` thanks to Marco Colombo. (#1713)

### Bug Fixes

Expand Down
15 changes: 14 additions & 1 deletion R/loo_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ E_loo_value <- function(x, psis_object, type = "mean", probs = 0.5) {
#' @aliases loo_R2
#'
#' @inheritParams bayes_R2.brmsfit
#' @param seed Optional integer used to initialize the random number
#' generator.
#' @param ... Further arguments passed to
#' \code{\link[brms:posterior_epred.brmsfit]{posterior_epred}} and
#' \code{\link[brms:log_lik.brmsfit]{log_lik}},
Expand Down Expand Up @@ -213,7 +215,8 @@ E_loo_value <- function(x, psis_object, type = "mean", probs = 0.5) {
#' @export loo_R2
#' @export
loo_R2.brmsfit <- function(object, resp = NULL, summary = TRUE,
robust = FALSE, probs = c(0.025, 0.975), ...) {
robust = FALSE, probs = c(0.025, 0.975),
seed = NULL, ...) {
contains_draws(object)
object <- restructure(object)
resp <- validate_resp(resp, object)
Expand All @@ -239,6 +242,16 @@ loo_R2.brmsfit <- function(object, resp = NULL, summary = TRUE,
"'loo_R2' which is likely invalid for ordinal families."
)
}

# set the random seed if required
if (!is.null(seed)) {
if (exists(".Random.seed", envir = .GlobalEnv)) {
rng_state_old <- get(".Random.seed", envir = .GlobalEnv)
on.exit(assign(".Random.seed", rng_state_old, envir = .GlobalEnv))
}
set.seed(seed)
}

args_y <- list(object, warn = TRUE, ...)
args_ypred <- list(object, sort = TRUE, ...)
R2 <- named_list(paste0("R2", resp))
Expand Down
4 changes: 4 additions & 0 deletions man/loo_R2.brmsfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9d1acf2

Please sign in to comment.