Skip to content

Commit

Permalink
Fix render_docs when save_in_bulk = FALSE (#200)
Browse files Browse the repository at this point in the history
* add `get_save_in_bulk()` method to `Experiment` class
* update render_docs to work when `save_in_bulk = FALSE`
  • Loading branch information
tiffanymtang authored Jan 6, 2025
1 parent e31617d commit 5f7844b
Show file tree
Hide file tree
Showing 5 changed files with 235 additions and 38 deletions.
18 changes: 16 additions & 2 deletions R/experiment.R
Original file line number Diff line number Diff line change
Expand Up @@ -1121,8 +1121,10 @@ Experiment <- R6::R6Class(
} else {
save_dir <- private$.get_vary_across_dir()
}
if (!dir.exists(file.path(save_dir, "fit_results"))) {
dir.create(file.path(save_dir, "fit_results"), recursive = TRUE)
if (!save_in_bulk) {
if (!dir.exists(file.path(save_dir, "fit_results"))) {
dir.create(file.path(save_dir, "fit_results"), recursive = TRUE)
}
}

dgp_list <- private$.get_obj_list("dgp")
Expand Down Expand Up @@ -2271,6 +2273,18 @@ Experiment <- R6::R6Class(
invisible(self)
},

#' @description Get the `save_in_bulk` parameter for the `Experiment`.
#'
#' @return Logical, indicating whether the results are saved in bulk or not.
get_save_in_bulk = function() {
save_in_bulk <- private$.save_in_bulk
if (is.null(save_in_bulk)) {
# for experiments created before save_in_bulk was introduced
save_in_bulk <- c(fit = TRUE, eval = TRUE, viz = TRUE)
}
return(save_in_bulk)
},

#' @description Export all cached `Visualizer` results from an
#' `Experiment` to images in the `viz_results/` directory under the
#' `Experiment`'s results directory (see [get_save_dir()]).
Expand Down
120 changes: 102 additions & 18 deletions inst/rmd/results.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,63 @@ show_recipe <- function(field_name = c(
#' Otherwise, the file is read in using data.table::fread().
#'
#' @param filename name of file (with file extension) to try reading in
#' @return output of filename if the file exists and NULL otherwise
get_results <- function(filename, filetype = ".rds") {
if (file.exists(filename)) {
if (filetype == ".rds") {
results <- readRDS(filename)
} else {
results <- data.table::fread(results)
#' @param filetype file extension
#' @param experiment experiment object
#' @param experiment_save_dir directory where experiment results are saved
#' @param field_name one of "evaluator" or "visualizer"
#' @return output of experiment results if the file exists and NULL otherwise
get_results <- function(filename, filetype = ".rds",
experiment, experiment_save_dir = NULL,
field_name = c("evaluator", "visualizer")) {
field_name <- match.arg(field_name)
if (field_name == "evaluator") {
save_in_bulk <- experiment$get_save_in_bulk()[["eval"]]
} else if (field_name == "visualizer") {
save_in_bulk <- experiment$get_save_in_bulk()[["viz"]]
}
results <- NULL
if (save_in_bulk) {
if (file.exists(filename)) {
if (filetype == ".rds") {
results <- readRDS(filename)
} else {
results <- data.table::fread(results)
}
}
} else {
results <- NULL
if (is.null(experiment_save_dir)) {
stop("experiment_save_dir must be provided if save_in_bulk is FALSE")
}
if (field_name == "evaluator") {
obj_names <- names(experiment$get_evaluators())
obj_dirname <- file.path(experiment_save_dir, "eval_results")
} else if (field_name == "visualizer") {
obj_names <- names(experiment$get_visualizers())
obj_dirname <- file.path(experiment_save_dir, "viz_results")
}
if (length(obj_names) == 0) {
return(NULL)
}
names(obj_names) <- obj_names
results <- purrr::map(
obj_names,
function(obj_name) {
obj_fname <- file.path(obj_dirname, sprintf("%s%s", obj_name, filetype))
if (file.exists(obj_fname)) {
if (filetype == ".rds") {
results <- readRDS(obj_fname)
} else {
results <- data.table::fread(obj_fname)
}
} else {
results <- NULL
}
return(results)
}
) |>
purrr::compact()
}
return(results)
}
Expand Down Expand Up @@ -596,39 +642,77 @@ get_exp_results <- function(dir_name,
eval_cache = ".rds", viz_cache = ".rds") {
exp_fname <- file.path(dir_name, "experiment.rds")
fit_fname <- file.path(dir_name, "fit_results.rds")
eval_fname <- file.path(dir_name, sprintf("eval_results%s", eval_cache))
viz_fname <- file.path(dir_name, sprintf("viz_results%s", viz_cache))
exp <- get_results(exp_fname)
if (file.exists(exp_fname)) {
exp <- readRDS(exp_fname)
} else {
results <- list(
exp = NULL,
eval_results = NULL,
viz_results = NULL
)
return(results)
}
fit_results <- NULL
eval_results <- NULL
viz_results <- NULL
if ((eval_cache != "none") && (viz_cache == ".rds")) {
if (show_eval) {
eval_results <- get_results(eval_fname, eval_cache)
eval_results <- get_results(
filename = eval_fname,
filetype = eval_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "evaluator"
)
}
if (show_viz) {
viz_results <- get_results(viz_fname, viz_cache)
viz_results <- get_results(
filename = viz_fname,
filetype = viz_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "visualizer"
)
}
} else {
if (show_eval) {
if (eval_cache == "none") {
fit_results <- get_results(fit_fname)
fit_results <- suppressMessages(get_cached_results(exp, "fit"))
if (is.null(fit_results)) {
stop("Cannot set eval_cache = 'none' since no cached fit results found. Perhaps try setting eval_cache = '.rds' instead.")
}
eval_results <- evaluate_experiment(exp, fit_results)
} else {
eval_results <- get_results(eval_fname, eval_cache)
eval_results <- get_results(
filename = eval_fname,
filetype = eval_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "evaluator"
)
}
}
if (show_viz) {
if (viz_cache == ".rds") {
viz_results <- get_results(viz_fname, viz_cache)
viz_results <- get_results(
filename = viz_fname,
filetype = viz_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "visualizer"
)
} else if (viz_cache == "none") {
if (is.null(fit_results)) {
fit_results <- get_results(fit_fname)
fit_results <- suppressMessages(get_cached_results(exp, "fit"))
if (is.null(fit_results)) {
stop("Cannot set viz_cache = 'none' since no cached fit results found. Perhaps try setting viz_cache = '.rds' instead.")
}
}
if (is.null(eval_results)) {
eval_results <- get_results(eval_fname)
eval_results <- suppressMessages(get_cached_results(exp, "eval"))
if (is.null(eval_results)) {
eval_results <- evaluate_experiment(exp, fit_results)
}
Expand Down Expand Up @@ -1044,7 +1128,7 @@ if (params$write) {
```{r evaluators, results = "asis"}
eval_recipe <- show_recipe(field_name = "evaluator", write_flag = params$write)
if (params$write) {
write_to_file(path = write_filename, "\n\n### Evaluation\n\n", eval_recipe)
write_to_file(path = write_filename, "\n\n## Evaluation\n\n", eval_recipe)
}
```

Expand Down
118 changes: 101 additions & 17 deletions inst/rmd/results_header_template.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,63 @@ get_descendants <- function(dir_name) {
#' Otherwise, the file is read in using data.table::fread().
#'
#' @param filename name of file (with file extension) to try reading in
#' @return output of filename if the file exists and NULL otherwise
get_results <- function(filename, filetype = ".rds") {
if (file.exists(filename)) {
if (filetype == ".rds") {
results <- readRDS(filename)
} else {
results <- data.table::fread(results)
#' @param filetype file extension
#' @param experiment experiment object
#' @param experiment_save_dir directory where experiment results are saved
#' @param field_name one of "evaluator" or "visualizer"
#' @return output of experiment results if the file exists and NULL otherwise
get_results <- function(filename, filetype = ".rds",
experiment, experiment_save_dir = NULL,
field_name = c("evaluator", "visualizer")) {
field_name <- match.arg(field_name)
if (field_name == "evaluator") {
save_in_bulk <- experiment$get_save_in_bulk()[["eval"]]
} else if (field_name == "visualizer") {
save_in_bulk <- experiment$get_save_in_bulk()[["viz"]]
}
results <- NULL
if (save_in_bulk) {
if (file.exists(filename)) {
if (filetype == ".rds") {
results <- readRDS(filename)
} else {
results <- data.table::fread(results)
}
}
} else {
results <- NULL
if (is.null(experiment_save_dir)) {
stop("experiment_save_dir must be provided if save_in_bulk is FALSE")
}
if (field_name == "evaluator") {
obj_names <- names(experiment$get_evaluators())
obj_dirname <- file.path(experiment_save_dir, "eval_results")
} else if (field_name == "visualizer") {
obj_names <- names(experiment$get_visualizers())
obj_dirname <- file.path(experiment_save_dir, "viz_results")
}
if (length(obj_names) == 0) {
return(NULL)
}
names(obj_names) <- obj_names
results <- purrr::map(
obj_names,
function(obj_name) {
obj_fname <- file.path(obj_dirname, sprintf("%s%s", obj_name, filetype))
if (file.exists(obj_fname)) {
if (filetype == ".rds") {
results <- readRDS(obj_fname)
} else {
results <- data.table::fread(obj_fname)
}
} else {
results <- NULL
}
return(results)
}
) |>
purrr::compact()
}
return(results)
}
Expand Down Expand Up @@ -132,39 +178,77 @@ get_exp_results <- function(dir_name,
eval_cache = ".rds", viz_cache = ".rds") {
exp_fname <- file.path(dir_name, "experiment.rds")
fit_fname <- file.path(dir_name, "fit_results.rds")
eval_fname <- file.path(dir_name, sprintf("eval_results%s", eval_cache))
viz_fname <- file.path(dir_name, sprintf("viz_results%s", viz_cache))
exp <- get_results(exp_fname)
if (file.exists(exp_fname)) {
exp <- readRDS(exp_fname)
} else {
results <- list(
exp = NULL,
eval_results = NULL,
viz_results = NULL
)
return(results)
}
fit_results <- NULL
eval_results <- NULL
viz_results <- NULL
if ((eval_cache != "none") && (viz_cache == ".rds")) {
if (show_eval) {
eval_results <- get_results(eval_fname, eval_cache)
eval_results <- get_results(
filename = eval_fname,
filetype = eval_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "evaluator"
)
}
if (show_viz) {
viz_results <- get_results(viz_fname, viz_cache)
viz_results <- get_results(
filename = viz_fname,
filetype = viz_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "visualizer"
)
}
} else {
if (show_eval) {
if (eval_cache == "none") {
fit_results <- get_results(fit_fname)
fit_results <- suppressMessages(get_cached_results(exp, "fit"))
if (is.null(fit_results)) {
stop("Cannot set eval_cache = 'none' since no cached fit results found. Perhaps try setting eval_cache = '.rds' instead.")
}
eval_results <- evaluate_experiment(exp, fit_results)
} else {
eval_results <- get_results(eval_fname, eval_cache)
eval_results <- get_results(
filename = eval_fname,
filetype = eval_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "evaluator"
)
}
}
if (show_viz) {
if (viz_cache == ".rds") {
viz_results <- get_results(viz_fname, viz_cache)
viz_results <- get_results(
filename = viz_fname,
filetype = viz_cache,
experiment = exp,
experiment_save_dir = dir_name,
field_name = "visualizer"
)
} else if (viz_cache == "none") {
if (is.null(fit_results)) {
fit_results <- get_results(fit_fname)
fit_results <- suppressMessages(get_cached_results(exp, "fit"))
if (is.null(fit_results)) {
stop("Cannot set viz_cache = 'none' since no cached fit results found. Perhaps try setting viz_cache = '.rds' instead.")
}
}
if (is.null(eval_results)) {
eval_results <- get_results(eval_fname)
eval_results <- suppressMessages(get_cached_results(exp, "eval"))
if (is.null(eval_results)) {
eval_results <- evaluate_experiment(exp, fit_results)
}
Expand Down
14 changes: 14 additions & 0 deletions man/Experiment.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion tests/testthat/test-docs.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ withr::with_tempdir(pattern = "simChef-test-checkpointing-temp", code = {
greatgrandchild2 <- create_experiment(
name = "greatgrandchild2",
clone_from = grandchild2,
save_dir = file.path(grandchild2$get_save_dir(), "greatgrandchild2")
save_dir = file.path(grandchild2$get_save_dir(), "greatgrandchild2"),
save_in_bulk = FALSE
)
results <- greatgrandchild2$run(save = TRUE, verbose = 0)
export_visualizers(greatgrandchild2)
Expand Down

0 comments on commit 5f7844b

Please sign in to comment.