Skip to content

Commit

Permalink
Add option to save and time fit results per replicate (#197)
Browse files Browse the repository at this point in the history
* record the time taken for each fitted replicate, evaluator, and visualizer

* add `save_in_bulk` option in `create_experiment()` to allow for saving/caching results per fitted replicate, evaluator, and visualizer

other improvements/bug fixes: 
* use `rbind()` instead of `dplyr::bind_rows()` in `compare_tibble_rows()` to avoid an error caused by mismatched column types
* streamline use of `simplify_tibble()`
* save results with only .rep, .dgp_name, .method_name simplified
* update tests and docs
* export `simplify_tibble()`

* closes #167 and #197
  • Loading branch information
tiffanymtang authored Jan 5, 2025
1 parent 15eaecc commit edaa82c
Show file tree
Hide file tree
Showing 18 changed files with 925 additions and 224 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export(save_experiment)
export(set_doc_options)
export(set_rmd_options)
export(set_save_dir)
export(simplify_tibble)
export(summarize_feature_importance)
export(summarize_feature_selection_curve)
export(summarize_feature_selection_err)
Expand Down
109 changes: 93 additions & 16 deletions R/experiment-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ maybe_add_debug_data <- function(tbl, debug = FALSE) {
#' Distribute simulation computation by replicates.
#'
#' @keywords internal
compute_rep <- function(n_reps,
compute_rep <- function(reps,
future.globals,
future.packages,
future.seed,
Expand Down Expand Up @@ -113,11 +113,18 @@ compute_rep <- function(n_reps,
}

# progress updates
total_reps <- n_reps * length(dgp_params_list) * length(method_params_list)
total_reps <- length(reps) * length(dgp_params_list) * length(method_params_list)
p <- maybe_progressr(steps = total_reps,
envir = parent.frame())

results <- future.apply::future_replicate(n_reps, {
vary_param_names <- purrr::map(
c(dgp_params_list, method_params_list),
~ names(.x)
) |>
purrr::reduce(c) |>
unique()

results <- future.apply::future_lapply(as.character(reps), function(i) {

# make a local binding to error_state
error_state <- error_state
Expand All @@ -143,6 +150,15 @@ compute_rep <- function(n_reps,
gc()
})

save_file <- file.path(
save_dir, "fit_results", sprintf("fit_result%s.rds", i)
)
if (use_cached && file.exists(save_file) && !save_in_bulk) {
cached_results <- readRDS(save_file)
} else {
cached_results <- NULL
}

dgp_res <- purrr::list_rbind(purrr::map(
dgp_params_list,
function(dgp_params) {
Expand Down Expand Up @@ -198,7 +214,8 @@ compute_rep <- function(n_reps,
}

return(
list(.dgp = dgp_list[[dgp_name]],
list(.rep = i,
.dgp = dgp_list[[dgp_name]],
.dgp_name = dgp_name,
.dgp_params = dgp_params,
.method = NULL,
Expand All @@ -207,6 +224,7 @@ compute_rep <- function(n_reps,
.method_output = NULL,
.err = data_list) |>
list_to_tibble_row() |>
simplify_tibble(cols = c(".rep", ".dgp_name", ".method_name")) |>
maybe_add_debug_data(TRUE)
)
}
Expand Down Expand Up @@ -240,14 +258,38 @@ compute_rep <- function(n_reps,
method_params = method_params,
duplicate_param_names = duplicate_param_names
) |>
list_to_tibble_row()
list_to_tibble_row() |>
simplify_tibble(cols = c(".rep", ".dgp_name", ".method_name"))

# param_df$.seed <- seed

method_params$.method_name <- NULL
method_params$data_list <- data_list
method_params$.simplify <- FALSE

if (use_cached && file.exists(save_file) && !save_in_bulk) {
is_cached <- compare_tibble_rows(
param_df,
cached_results |>
dplyr::select(tidyselect::all_of(colnames(param_df))),
op = "contained_in"
) &&
compare_tibble_rows(
param_df,
cached_fit_params |>
dplyr::select(tidyselect::all_of(colnames(param_df))),
op = "contained_in"
)
if (is_cached) {
# if (verbose >= 1) {
# inform(sprintf("Found cached results for rep=%s for", i))
# inform(str(simplify_tibble(param_df)))
# }
return(NULL)
}
}

fit_start_time <- Sys.time()
result <- do_call_wrapper(
method_name,
method_list[[method_name]]$fit,
Expand All @@ -256,6 +298,7 @@ compute_rep <- function(n_reps,
# hard-coded method fun call for error messages
call = rlang::call2(paste0(method_name, "$method_fun(...)"))
)
fit_time <- difftime(Sys.time(), fit_start_time, units = "mins")

if ("error" %in% class(result)) {

Expand All @@ -270,7 +313,8 @@ compute_rep <- function(n_reps,
method_params$data_list <- NULL

return(
list(.dgp = dgp_list[[dgp_name]],
list(.rep = i,
.dgp = dgp_list[[dgp_name]],
.dgp_name = dgp_name,
.dgp_params = dgp_params,
.method = method_list[[method_name]],
Expand All @@ -279,6 +323,7 @@ compute_rep <- function(n_reps,
.method_output = NULL,
.err = result) |>
list_to_tibble_row() |>
simplify_tibble(cols = c(".rep", ".dgp_name", ".method_name")) |>
maybe_add_debug_data(TRUE)
)
}
Expand All @@ -304,7 +349,8 @@ compute_rep <- function(n_reps,
method_params$data_list <- NULL

return(
list(.dgp = dgp_list[[dgp_name]],
list(.rep = i,
.dgp = dgp_list[[dgp_name]],
.dgp_name = dgp_name,
.dgp_params = dgp_params,
.method = method_list[[method_name]],
Expand All @@ -313,12 +359,18 @@ compute_rep <- function(n_reps,
.method_output = result,
.err = names_check) |>
list_to_tibble_row() |>
simplify_tibble(cols = c(".rep", ".dgp_name", ".method_name")) |>
maybe_add_debug_data(TRUE)
)
}

result <- result |>
tibble::add_column(param_df, .before = 1)
tibble::add_column(param_df, .before = 1) |>
tibble::add_column(.rep = i, .before = 1)

if (record_time) {
result$.time_taken <- fit_time
}

p("of total reps")

Expand All @@ -332,16 +384,41 @@ compute_rep <- function(n_reps,
}
)) # dgp_res <- purrr::list_rbind(purrr::map(

if (use_cached && file.exists(save_file) && !save_in_bulk) {
dgp_res <- get_matching_rows(
id = cached_fit_params, x = cached_results
) |>
dplyr::bind_rows(dgp_res)
}

if (save_per_rep) {
if (".err" %in% colnames(dgp_res)) {
saveRDS(
dgp_res,
stringr::str_replace(save_file, "\\.rds$", "_error.rds")
)
} else {
saveRDS(dgp_res, save_file)
}
dgp_res <- dgp_res |>
dplyr::select(tidyselect::any_of(unique(c(
".rep", ".dgp", ".dgp_name", ".dgp_params",
".method", ".method_name", ".method_params",
".method_output",
vary_param_names, duplicate_param_names,
".err", ".pid", ".gc"
))))
}

return(dgp_res)

},
simplify = FALSE,
future.globals = future.globals,
future.packages = future.packages,
future.seed = future.seed,
...) # results <- future.apply::future_lapply(

results <- dplyr::bind_rows(results, .id = ".rep")
results <- dplyr::bind_rows(results)

if (debug) {

Expand Down Expand Up @@ -377,7 +454,7 @@ compute_rep <- function(n_reps,
#' Distribute simulation computation by DGPs.
#'
#' @keywords internal
compute_dgp <- function(n_reps,
compute_dgp <- function(reps,
future.globals,
future.packages,
future.seed,
Expand All @@ -389,7 +466,7 @@ compute_dgp <- function(n_reps,
#' Distribute simulation computation by Methods.
#'
#' @keywords internal
compute_method <- function(n_reps,
compute_method <- function(reps,
future.globals,
future.packages,
future.seed,
Expand All @@ -401,7 +478,7 @@ compute_method <- function(n_reps,
#' Doubly nested distributed simulation computation nested by DGPs and reps.
#'
#' @keywords internal
compute_dgp_rep <- function(n_reps,
compute_dgp_rep <- function(reps,
future.globals,
future.packages,
future.seed,
Expand All @@ -413,7 +490,7 @@ compute_dgp_rep <- function(n_reps,
#' Doubly nested distributed simulation computation nested by Methods and reps.
#'
#' @keywords internal
compute_method_rep <- function(n_reps,
compute_method_rep <- function(reps,
future.globals,
future.packages,
future.seed,
Expand All @@ -425,7 +502,7 @@ compute_method_rep <- function(n_reps,
#' Doubly nested distributed simulation computation nested by DGPs and Methods.
#'
#' @keywords internal
compute_dgp_method <- function(n_reps,
compute_dgp_method <- function(reps,
future.globals,
future.packages,
future.seed,
Expand All @@ -438,7 +515,7 @@ compute_dgp_method <- function(n_reps,
#' reps.
#'
#' @keywords internal
compute_dgp_method_reps <- function(n_reps,
compute_dgp_method_reps <- function(reps,
future.globals,
future.packages,
future.seed,
Expand Down
Loading

0 comments on commit edaa82c

Please sign in to comment.