Skip to content

Commit

Permalink
fix caching bug with function arguments (#198)
Browse files Browse the repository at this point in the history
* fix caching bug with function arguments
* change removeSource to utils::removeSource
* remove simplify_tibble from compute_reps() environment

Closes #168
  • Loading branch information
tiffanymtang authored Jan 6, 2025
1 parent 5f7844b commit d93bd30
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 36 deletions.
9 changes: 6 additions & 3 deletions R/experiment-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -272,13 +272,15 @@ compute_rep <- function(reps,
param_df,
cached_results |>
dplyr::select(tidyselect::all_of(colnames(param_df))),
op = "contained_in"
op = "contained_in",
vary_param_names = c(vary_param_names, duplicate_param_names)
) &&
compare_tibble_rows(
param_df,
cached_fit_params |>
dplyr::select(tidyselect::all_of(colnames(param_df))),
op = "contained_in"
op = "contained_in",
vary_param_names = c(vary_param_names, duplicate_param_names)
)
if (is_cached) {
# if (verbose >= 1) {
Expand Down Expand Up @@ -386,7 +388,8 @@ compute_rep <- function(reps,

if (use_cached && file.exists(save_file) && !save_in_bulk) {
dgp_res <- get_matching_rows(
id = cached_fit_params, x = cached_results
id = cached_fit_params, x = cached_results,
vary_param_names = c(vary_param_names, duplicate_param_names)
) |>
dplyr::bind_rows(dgp_res)
}
Expand Down
50 changes: 29 additions & 21 deletions R/experiment.R
Original file line number Diff line number Diff line change
Expand Up @@ -318,17 +318,17 @@ Experiment <- R6::R6Class(
dplyr::mutate(
.dgp_name = purrr::map_chr(.dgp, ~.x$.dgp_name),
.dgp_fun = purrr::map(
.dgp, ~removeSource(dgp_list[[.x$.dgp_name]]$dgp_fun)
.dgp, ~ utils::removeSource(dgp_list[[.x$.dgp_name]]$dgp_fun)
),
.dgp_params = purrr::map(
.dgp, ~dgp_list[[.x$.dgp_name]]$dgp_params
.dgp, ~ dgp_list[[.x$.dgp_name]]$dgp_params
),
.method_name = purrr::map_chr(.method, ~.x$.method_name),
.method_fun = purrr::map(
.method, ~removeSource(method_list[[.x$.method_name]]$method_fun)
.method, ~ utils::removeSource(method_list[[.x$.method_name]]$method_fun)
),
.method_params = purrr::map(
.method, ~method_list[[.x$.method_name]]$method_params
.method, ~ method_list[[.x$.method_name]]$method_params
)
)

Expand Down Expand Up @@ -405,7 +405,7 @@ Experiment <- R6::R6Class(
obj_params <- tibble::tibble(
name = names(obj_list),
fun = purrr::map(
obj_list, ~removeSource(.x[[sprintf("%s_fun", field_name)]])
obj_list, ~ utils::removeSource(.x[[sprintf("%s_fun", field_name)]])
),
params = purrr::map(
obj_list, sprintf("%s_params", field_name)
Expand Down Expand Up @@ -1163,7 +1163,10 @@ Experiment <- R6::R6Class(
results <- private$.get_cached_results("fit", verbose = verbose)
fit_params <- private$.get_fit_params(wide_params = TRUE)

fit_results <- get_matching_rows(id = fit_params, x = results) |>
fit_results <- get_matching_rows(
id = fit_params, x = results,
vary_param_names = private$.get_vary_params()
) |>
dplyr::select(
.rep, .dgp_name, .method_name, private$.get_vary_params(),
tidyselect::everything()
Expand Down Expand Up @@ -1273,6 +1276,19 @@ Experiment <- R6::R6Class(

new_fit_results <- local({

# Thin wrapper around do_call_handler(): forwards `name`, `fun`, `params`,
# `verbose`, and `call` unchanged. If do_call_handler() signals an error, the
# error is caught and the condition object itself becomes the return value
# instead of the error propagating to the caller.
do_call_wrapper <- function(name,
fun,
params,
verbose,
call) {
tryCatch(
do_call_handler(
name, fun, params, verbose, call
),
# `identity` returns the caught condition object as-is
error = identity
)
}

# create an env with objs/funcs that the future workers need
workenv <- rlang::new_environment(
data = list(
Expand All @@ -1289,19 +1305,7 @@ Experiment <- R6::R6Class(
save_per_rep = save_per_rep,
use_cached = use_cached && (nrow(cached_fit_params) > 0),
save_dir = save_dir,
simplify_tibble = simplify_tibble,
do_call_wrapper = function(name,
fun,
params,
verbose,
call) {
tryCatch(
do_call_handler(
name, fun, params, verbose, call
),
error = identity
)
}
do_call_wrapper = do_call_wrapper
),
parent = rlang::ns_env()
)
Expand Down Expand Up @@ -1395,14 +1399,18 @@ Experiment <- R6::R6Class(
"fit", verbose = verbose
)
fit_results_cached <- get_matching_rows(
id = fit_params_cached, x = fit_results_cached
id = fit_params_cached, x = fit_results_cached,
vary_param_names = private$.get_vary_params()
)
if (verbose >= 1) {
inform("Appending cached results to the new fit results...")
}
fit_params <- private$.get_fit_params(wide_params = TRUE)
fit_results <- dplyr::bind_rows(fit_results, fit_results_cached)
fit_results <- get_matching_rows(id = fit_params, x = fit_results) |>
fit_results <- get_matching_rows(
id = fit_params, x = fit_results,
vary_param_names = private$.get_vary_params()
) |>
dplyr::select(
.rep, .dgp_name, .method_name, private$.get_vary_params(),
tidyselect::everything()
Expand Down
74 changes: 67 additions & 7 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,16 @@ fix_duplicate_param_names <- function(dgp_params, method_params,
#' @param x A tibble with unique rows.
#' @param y A tibble with unique rows.
#' @param op Name of operation.
#' @param vary_param_names Character vector of parameter names that are varied
#' across in the Experiment.
#'
#' @return If `op == "equal"`, returns a boolean indicating if `x` and
#' `y` have the same rows, ignoring the row order. If
#' `op == "contained_in"`, returns a boolean indicating if all rows in
#' `x` are contained in the rows of `y`.
#' @keywords internal
compare_tibble_rows <- function(x, y, op = c("equal", "contained_in")) {
compare_tibble_rows <- function(x, y, op = c("equal", "contained_in"),
vary_param_names = NULL) {
op <- match.arg(op)
if ((!tibble::is_tibble(x)) || (!tibble::is_tibble(y))) {
abort("x and y must be tibbles.")
Expand All @@ -335,7 +338,11 @@ compare_tibble_rows <- function(x, y, op = c("equal", "contained_in")) {
return(FALSE)
}
}
duplicated_rows <- rbind(x, y) |>
# manually remove source code from functions
duplicated_rows <- rbind(
remove_source(x, cols = vary_param_names),
remove_source(y, cols = vary_param_names)
) |>
duplicated(fromLast = TRUE)
return(all(duplicated_rows[1:nrow(x)]))
}
Expand All @@ -352,13 +359,14 @@ compare_tibble_rows <- function(x, y, op = c("equal", "contained_in")) {
#' distinct rows while `dplyr::inner_join` does not. This function
#' enables caching when functions are used as parameters in DGPs and Methods.
#'
#' @inheritParams compare_tibble_rows
#' @param id A tibble with distinct rows.
#' @param x A tibble.
#'
#' @return A tibble, containing the subset of rows from `x` that match
#' id rows from `id`.
#' @keywords internal
get_matching_rows <- function(id, x) {
get_matching_rows <- function(id, x, vary_param_names = NULL) {
if ((!tibble::is_tibble(id)) || (!tibble::is_tibble(x))) {
abort("id and x must be tibbles.")
}
Expand All @@ -371,13 +379,17 @@ get_matching_rows <- function(id, x) {
stop("id must be a tibble with unique rows.")
}
x_ids <- x |>
dplyr::select(tidyselect::all_of(id_cols))
dplyr::select(tidyselect::all_of(id_cols)) |>
remove_source(cols = vary_param_names)
if (!any(id_coltypes == "list")) {
# easy case: no functions in id tibble -> use inner_join
# easy case: definitely no functions in id tibble -> use inner_join
out <- dplyr::inner_join(id, x, by = id_cols)
} else if (!anyDuplicated(x_ids)) {
# no duplicate id rows in x -> use duplicated
df <- dplyr::bind_rows(id, x)
df <- dplyr::bind_rows(
remove_source(id, cols = vary_param_names),
remove_source(x, cols = vary_param_names)
)
keep_row_idx <- df |>
dplyr::select(tidyselect::all_of(id_cols)) |>
duplicated()
Expand All @@ -388,7 +400,10 @@ get_matching_rows <- function(id, x) {
keep_row_idx <- purrr::map_lgl(
1:nrow(x),
function(i) {
dplyr::bind_rows(id, x_ids[i, ]) |>
dplyr::bind_rows(
remove_source(id, cols = vary_param_names),
remove_source(x_ids[i, ], cols = vary_param_names)
) |>
duplicated() |>
dplyr::last()
}
Expand Down Expand Up @@ -542,3 +557,48 @@ HTML <- function(text, ..., .noWS = NULL) {
class(htmlText) <- c("html", "character")
return(htmlText)
}


#' Remove source code from functions.
#'
#' @description Strip the source code (via [utils::removeSource()]) from any
#'   functions stored in a tibble so that caching works as expected. The
#'   `.dgp_params`, `.method_params`, `.eval_params`, and `.viz_params`
#'   columns are always processed when present: each entry of these columns is
#'   treated as a collection of parameters, and every parameter that is a
#'   function has its source code removed. Columns named in `cols` (e.g., the
#'   `vary_across` parameters) are processed only if they are present in `x`
#'   and are list columns; entries of those columns that are functions have
#'   their source code removed.
#'
#' @param x A tibble.
#' @param cols Character vector of column names to potentially remove source
#'   code from. Names that are not columns of `x` are ignored.
#'
#' @returns A tibble with the source code removed from functions in the
#'   specified columns. Non-function entries are left untouched.
#' @keywords internal
remove_source <- function(x, cols = NULL) {
  # drop the source code from a single object if (and only if) it is a function
  strip_fun <- function(obj) {
    if (is.function(obj)) {
      utils::removeSource(obj)
    } else {
      obj
    }
  }

  # parameter columns: every entry is itself a collection of parameters
  param_cols <- c(
    ".dgp_params", ".method_params", ".eval_params", ".viz_params"
  )
  for (param_col in param_cols) {
    if (!(param_col %in% names(x))) {
      next
    }
    x[[param_col]] <- purrr::map(
      x[[param_col]],
      function(param_set) purrr::map(param_set, strip_fun)
    )
  }

  # user-specified columns: only list columns are mapped over
  # (iterating over `cols = NULL` is a no-op)
  for (extra_col in cols) {
    if (!(extra_col %in% names(x)) || !is.list(x[[extra_col]])) {
      next
    }
    x[[extra_col]] <- purrr::map(x[[extra_col]], strip_fun)
  }

  x
}
10 changes: 9 additions & 1 deletion man/compare_tibble_rows.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/get_matching_rows.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/remove_source.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d93bd30

Please sign in to comment.