Skip to content

Commit

Permalink
Add gr timings and fix time.total
Browse files Browse the repository at this point in the history
  • Loading branch information
Cole-Monnahan-NOAA committed Jun 11, 2024
1 parent c6ffdef commit 8e2ee2d
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions R/sparse.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
#' @return A fitted MCMC object of class 'adfit'
#' @export
sample_sparse_tmb <- function(obj, iter, warmup, cores, chains,
control=NULL, seed=NULL, metric=c('sparse','dense','diag')){
control=NULL, seed=NULL,
metric=c('sparse','dense','diag')){
iter <- iter-warmup
metric <- match.arg(metric)
obj$env$beSilent()
Expand All @@ -35,6 +36,7 @@ sample_sparse_tmb <- function(obj, iter, warmup, cores, chains,
## rebuild without random effects
mydll <- unclass(getLoadedDLLs()[[obj$env$DLL]])$path
isRTMB <- ifelse(obj$env$DLL=='RTMB', TRUE, FALSE)
message("Rebuilding obj without random effects...")
if(!isRTMB){
obj2 <- TMB::MakeADFun(data=obj$env$data, parameters=obj$env$parList(),
map=obj$env$map,
Expand All @@ -58,16 +60,31 @@ sample_sparse_tmb <- function(obj, iter, warmup, cores, chains,
} else {
packages = c("RTMB", "Matrix")
}
message("Starting MCMC sampling...")
fit <- stan_sample(fn=fsparse, par_inits=initssparse,
grad_fun=gsparse, num_samples=iter,
num_warmup=warmup,
globals = globals, packages=packages,
adapt_delta=control$adapt_delta,
parallel_chains=cores, save_warmup=TRUE,
num_chains = chains, seed = seed)

fit2 <- as.tmbfit(fit, mle=mle, invf=finv)
fit2$time.Q <- time.Q; fit2$time.Qinv <- time.Qinv
## gradient timings to check for added overhead
if(require(microbenchmark)){
bench <- microbenchmark(obj2$gr(init),
gsparse(initssparse),
times=500, unit='s')
fit2$time.gr <- summary(bench)$median[1]
fit2$time.gr2 <- summary(bench)$median[2]
} else {
warning("Package microbenchmark required to do accurate gradient timings, using system.time() instead")
fit2$time.gr <-
as.numeric(system.time(trash <- replicate(1000, obj2$gr(init)))[3])
fit2$time.gr2 <-
as.numeric(system.time(trash <- replicate(1000, gsparse(initssparse)))[3])
}
fit2$metric <- metric
print(fit2)
fit2
}
Expand Down Expand Up @@ -124,8 +141,10 @@ as.tmbfit <- function(x, mle, invf){
par_names=mle$parnames,
max_treedepth=x@metadata$max_depth,
warmup=as.numeric(x@metadata$num_warmup),
time.warmup=timing[1,], time.total=timing[1,]+timing[,2],
time.warmup=timing[1,], time.total=timing[1,]+timing[2,],
## iter=as.numeric(x@metadata$num_samples)+as.numeric(x@metadata$num_warmup),
algorithm='NUTS')
adfit(x)
}


0 comments on commit 8e2ee2d

Please sign in to comment.