require optional arguments to be named (#863)
* move dots in `augment.tune_results()`

* move dots in `collect_*()` functions

* move dots in `compute_metrics()`

* move dots in developer-focused functions

* move dots in `autoplot.tune_results()`

* add dots to relevant developer-focused functions

* move dots in `*_best()`

* add dots to `conf_mat_resampled()`

* check dots are empty in functions that newly have them

* move dots in `first_eval_time()`

* check existing but newly moved dots are empty

* move dots in `fit_best()`

* revert "move dots in `autoplot.tune_results()`"

Those dots are actually passed on to internal functions and aren't just to enforce naming arguments.

* add NEWS entry

* correct ref, note exception

* re`document()`

* name `collect_predictions()` argument

* write out ad-hoc dots check
simonpcouch authored Mar 1, 2024
1 parent 8d621ca commit 884cc7a
Showing 28 changed files with 123 additions and 165 deletions.
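
The pattern this commit applies can be sketched in base R. This is a minimal illustration, not code from tune: `summarize_values()` and its arguments are hypothetical, and the `...length()` test is a base-R stand-in for the `rlang::check_dots_empty()` call used in the diff below. Because R matches arguments that follow `...` by name only, anything passed by position beyond the required arguments lands in `...` and can be rejected.

```r
# Hypothetical function, not part of tune: the optional arguments follow
# `...`, so they can only be matched by name.
summarize_values <- function(x, ..., trim = 0, na_rm = FALSE) {
  # Base-R stand-in for rlang::check_dots_empty(): error if anything ended
  # up in `...`, which is what happens when an optional argument is
  # passed by position.
  if (...length() > 0) {
    stop("`...` must be empty; optional arguments must be named.", call. = FALSE)
  }
  mean(x, trim = trim, na.rm = na_rm)
}

vals <- c(1, 2, 3, NA)

# Optional argument passed by name: accepted.
named_result <- summarize_values(vals, na_rm = TRUE)

# Optional arguments passed by position: captured by `...` and rejected.
positional_result <- tryCatch(
  summarize_values(vals, 0, TRUE),
  error = function(e) conditionMessage(e)
)
```

In the sketch, the named call returns the mean of the non-missing values, while the positional call fails with the error message instead of silently binding `0` and `TRUE` to `trim` and `na_rm`.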
5 changes: 3 additions & 2 deletions NEWS.md
@@ -45,10 +45,11 @@

* For iterative optimization routines, `autoplot()` will use integer breaks when `type = "performance"` or `type = "parameters"`.

## Breaking Change
## Breaking Changes

* Several functions gain an `eval_time` argument for the evaluation time of dynamic metrics for censored regression. The placement of the argument breaks passing-by-position for one or more other arguments to `fit_best.tune_results()`, `show_best.tune_results()`, and the developer-focused `check_initial()` (#857).
* Several functions gained an `eval_time` argument for the evaluation time of dynamic metrics for censored regression. The placement of the argument breaks passing-by-position for one or more other arguments to `autoplot.tune_results()` and the developer-focused `check_initial()` (#857).

* Ellipses (...) are now used consistently in the package to require optional arguments to be named. For functions that previously had ellipses at the end of the function signature, they have been moved to follow the last argument without a default value: this applies to `augment.tune_results()`, `collect_predictions.tune_results()`, `collect_metrics.tune_results()`, and the developer-focused `estimate_tune_results()`, `load_pkgs()`, and `encode_set()`. Several other functions that previously did not have ellipses in their signatures gained them: this applies to `conf_mat_resampled()` and the developer-focused `check_workflow()`. Optional arguments previously passed by position will now error informatively prompting them to be named. These changes don't apply in cases when the ellipses are currently in use to forward arguments to other functions (#863).

# tune 1.1.2

45 changes: 9 additions & 36 deletions R/augment.R
@@ -7,12 +7,12 @@
#' @param x An object resulting from one of the `tune_*()` functions,
#' `fit_resamples()`, or `last_fit()`. The control specifications for these
#' objects should have used the option `save_pred = TRUE`.
#' @param ... Not currently used.
#' @param parameters A data frame with a single row that indicates what
#' tuning parameters should be used to generate the predictions (for `tune_*()`
#' objects only). If `NULL`, `select_best(x)` will be used with the first
#' metric and, if applicable, the first evaluation time point, used to
#' objects only). If `NULL`, `select_best(x)` will be used with the first
#' metric and, if applicable, the first evaluation time point, used to
#' create `x`.
#' @param ... Not currently used.
#' @return A data frame with one or more additional columns for model
#' predictions.
#'
@@ -34,24 +34,15 @@
#' results.
#'
#' @export
augment.tune_results <- function(x, parameters = NULL, ...) {
dots <- rlang::list2(...)
if (length(dots) > 0) {
rlang::abort(
paste(
"The only two arguments for `augment.tune_results()` are",
"'x' and 'parameters'. Others were passed:",
paste0("'", names(dots), "'", collapse = ", ")
)
)
}
augment.tune_results <- function(x, ..., parameters = NULL) {
rlang::check_dots_empty()

# check/determine best settings
if (is.null(parameters)) {
obj_fun <- .get_tune_metric_names(x)[1]
obj_eval_time <- choose_eval_time(
x,
metric = obj_fun,
x,
metric = obj_fun,
eval_time = NULL,
quietly = TRUE
)
@@ -70,16 +61,7 @@ augment.tune_results <- function(x, parameters = NULL, ...) {
#' @rdname augment.tune_results
#' @export
augment.resample_results <- function(x, ...) {
dots <- rlang::list2(...)
if (length(dots) > 0) {
rlang::abort(
paste(
"The only argument for `augment.fit_resamples()` is",
"'x'. Others were passed:",
paste0("'", names(dots), "'", collapse = ", ")
)
)
}
rlang::check_dots_empty()

pred <- collect_predictions(x, summarize = TRUE)
y_nm <- .get_tune_outcome_names(x)
@@ -91,16 +73,7 @@ augment.resample_results <- function(x, ...) {
#' @rdname augment.tune_results
#' @export
augment.last_fit <- function(x, ...) {
dots <- rlang::list2(...)
if (length(dots) > 0) {
rlang::abort(
paste(
"The only argument for `augment.last_fit()` is",
"'x'. Others were passed:",
paste0("'", names(dots), "'", collapse = ", ")
)
)
}
rlang::check_dots_empty()

pred <- collect_predictions(x, summarize = TRUE)
pred$.row <- 1:nrow(pred)
4 changes: 3 additions & 1 deletion R/checks.R
@@ -275,7 +275,7 @@ check_param_objects <- function(pset) {
#' @keywords internal
#' @rdname empty_ellipses
#' @param check_dials A logical for check for a NULL parameter object.
check_workflow <- function(x, pset = NULL, check_dials = FALSE, call = caller_env()) {
check_workflow <- function(x, ..., pset = NULL, check_dials = FALSE, call = caller_env()) {
if (!inherits(x, "workflow")) {
rlang::abort("The `object` argument should be a 'workflow' object.")
}
@@ -288,6 +288,8 @@ check_workflow <- function(x, pset = NULL, check_dials = FALSE, call = caller_en
rlang::abort("A parsnip model is required.")
}

rlang::check_dots_empty(call = call)

if (check_dials) {
if (is.null(pset)) {
pset <- hardhat::extract_parameter_set_dials(x)
18 changes: 13 additions & 5 deletions R/collect.R
@@ -3,6 +3,7 @@
#' @param x The results of [tune_grid()], [tune_bayes()], [fit_resamples()],
#' or [last_fit()]. For [collect_predictions()], the control option `save_pred
#' = TRUE` should have been used.
#' @param ... Not currently used.
#' @param summarize A logical; should metrics be summarized over resamples
#' (`TRUE`) or return the values for each individual resample. Note that, if `x`
#' is created by [last_fit()], `summarize` has no effect. For the other object
@@ -17,7 +18,6 @@
#' each metric has its own column and the `n` and `std_err` columns are removed,
#' if they exist.
#'
#' @param ... Not currently used.
#' @return A tibble. The column names depend on the results and the mode of the
#' model.
#'
@@ -120,7 +120,11 @@
#'
#' collect_predictions(resampled) %>% arrange(.row)
#' collect_predictions(resampled, summarize = TRUE) %>% arrange(.row)
#' collect_predictions(resampled, summarize = TRUE, grid[1, ]) %>% arrange(.row)
#' collect_predictions(
#' resampled,
#' summarize = TRUE,
#' parameters = grid[1, ]
#' ) %>% arrange(.row)
#'
#' collect_extracts(resampled)
#'
@@ -139,7 +143,9 @@ collect_predictions.default <- function(x, ...) {

#' @export
#' @rdname collect_predictions
collect_predictions.tune_results <- function(x, summarize = FALSE, parameters = NULL, ...) {
collect_predictions.tune_results <- function(x, ..., summarize = FALSE, parameters = NULL) {
rlang::check_dots_empty()

names <- colnames(x)
coll_col <- ".predictions"

@@ -454,7 +460,8 @@ collect_metrics.default <- function(x, ...) {

#' @export
#' @rdname collect_predictions
collect_metrics.tune_results <- function(x, summarize = TRUE, type = c("long", "wide"), ...) {
collect_metrics.tune_results <- function(x, ..., summarize = TRUE, type = c("long", "wide")) {
rlang::check_dots_empty()
rlang::arg_match0(type, values = c("long", "wide"))

if (inherits(x, "last_fit")) {
@@ -551,7 +558,8 @@ collector <- function(x, coll_col = ".predictions") {
#' @export
#' @keywords internal
#' @rdname empty_ellipses
estimate_tune_results <- function(x, col_name = ".metrics", ...) {
estimate_tune_results <- function(x, ..., col_name = ".metrics") {
rlang::check_dots_empty()
param_names <- .get_tune_parameter_names(x)
id_names <- grep("^id", names(x), value = TRUE)
group_cols <- .get_extra_col_names(x)
5 changes: 3 additions & 2 deletions R/compute_metrics.R
@@ -83,9 +83,10 @@ compute_metrics.default <- function(x,
#' @rdname compute_metrics
compute_metrics.tune_results <- function(x,
metrics,
...,
summarize = TRUE,
event_level = "first",
...) {
event_level = "first") {
rlang::check_dots_empty()
if (!".predictions" %in% names(x)) {
rlang::abort(paste0(
"`x` must have been generated with the ",
4 changes: 3 additions & 1 deletion R/conf_mat_resampled.R
@@ -5,6 +5,7 @@
#'
#' @param x An object with class `tune_results` that was used with a
#' classification model that was run with `control_*(save_pred = TRUE)`.
#' @param ... Currently unused, must be empty.
#' @param parameters A tibble with a single tuning parameter combination. Only
#' one tuning parameter combination (if any were used) is allowed here.
#' @param tidy Should the results come back in a tibble (`TRUE`) or a `conf_mat`
@@ -30,7 +31,8 @@
#' conf_mat_resampled(res)
#' conf_mat_resampled(res, tidy = FALSE)
#' @export
conf_mat_resampled <- function(x, parameters = NULL, tidy = TRUE) {
conf_mat_resampled <- function(x, ..., parameters = NULL, tidy = TRUE) {
rlang::check_dots_empty()
if (!inherits(x, "tune_results")) {
rlang::abort(
"The first argument needs to be an object with class 'tune_results'."
10 changes: 4 additions & 6 deletions R/fit_best.R
@@ -6,6 +6,7 @@
#' @param x The results of class `tune_results` (coming from functions such as
#' [tune_grid()], [tune_bayes()], etc). The control option
#' [`save_workflow = TRUE`][tune::control_grid] should have been used.
#' @param ... Not currently used, must be empty.
#' @param metric A character string (or `NULL`) for which metric to optimize. If
#' `NULL`, the first metric is used.
#' @param parameters An optional 1-row tibble of tuning parameter settings, with
@@ -22,7 +23,6 @@
#' `NULL`, the validation set is not used for resamples originating from
#' [rsample::validation_set()] while it is used for resamples originating
#' from [rsample::validation_split()].
#' @param ... Not currently used.
#' @inheritParams select_best
#' @details
#' This function is a shortcut for the manual steps of:
@@ -88,15 +88,13 @@ fit_best.default <- function(x, ...) {
#' @export
#' @rdname fit_best
fit_best.tune_results <- function(x,
...,
metric = NULL,
eval_time = NULL,
parameters = NULL,
verbose = FALSE,
add_validation_set = NULL,
...) {
if (length(list(...))) {
cli::cli_abort(c("x" = "The `...` are not used by this function."))
}
add_validation_set = NULL) {
rlang::check_dots_empty()
wflow <- .get_tune_workflow(x)
if (is.null(wflow)) {
cli::cli_abort(c("x" = "The control option `save_workflow = TRUE` should be used when tuning."))
9 changes: 6 additions & 3 deletions R/load_ns.R
@@ -8,17 +8,19 @@
#' @return An invisible NULL.
#' @keywords internal
#' @export
load_pkgs <- function(x, infra = TRUE, ...) {
load_pkgs <- function(x, ..., infra = TRUE) {
UseMethod("load_pkgs")
}

#' @export
load_pkgs.character <- function(x, ...) {
rlang::check_dots_empty()
withr::with_preserve_seed(.load_namespace(x))
}

#' @export
load_pkgs.model_spec <- function(x, infra = TRUE, ...) {
load_pkgs.model_spec <- function(x, ..., infra = TRUE) {
rlang::check_dots_empty()
pkgs <- required_pkgs(x)
if (infra) {
pkgs <- c(infra_pkgs, pkgs)
@@ -27,7 +29,8 @@ load_pkgs.model_spec <- function(x, infra = TRUE, ...) {
}

#' @export
load_pkgs.workflow <- function(x, infra = TRUE, ...) {
load_pkgs.workflow <- function(x, ..., infra = TRUE) {
rlang::check_dots_empty()
load_pkgs.model_spec(extract_spec_parsnip(x), infra = infra)
}

15 changes: 9 additions & 6 deletions R/metric-selection.R
@@ -68,7 +68,8 @@ check_mult_metrics <- function(metric, ..., call = rlang::caller_env()) {

#' @rdname choose_metric
#' @export
check_metric_in_tune_results <- function(mtr_info, metric, call = rlang::caller_env()) {
check_metric_in_tune_results <- function(mtr_info, metric, ..., call = rlang::caller_env()) {
rlang::check_dots_empty(call = call)
if (!any(mtr_info$metric == metric)) {
cli::cli_abort("{.val {metric}} was not in the metric set. Please choose
from: {.val {mtr_info$metric}}.", call = call)
@@ -97,8 +98,8 @@ contains_survival_metric <- function(mtr_info) {
# choose_eval_time() is called by show_best(), select_best(), and augment()
#' @rdname choose_metric
#' @export
choose_eval_time <- function(x, metric, eval_time = NULL, quietly = FALSE, call = rlang::caller_env()) {

choose_eval_time <- function(x, metric, ..., eval_time = NULL, quietly = FALSE, call = rlang::caller_env()) {
rlang::check_dots_empty(call = call)
mtr_set <- .get_tune_metrics(x)
mtr_info <- tibble::as_tibble(mtr_set)

@@ -183,7 +184,7 @@ first_metric <- function(mtr_set) {
# such as tune_bayes().
#' @rdname choose_metric
#' @export
first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL, ..., quietly = FALSE, call = rlang::caller_env()) {
first_eval_time <- function(mtr_set, ..., metric = NULL, eval_time = NULL, quietly = FALSE, call = rlang::caller_env()) {
rlang::check_dots_empty()

num_times <- length(eval_time)
@@ -253,7 +254,8 @@ first_eval_time <- function(mtr_set, metric = NULL, eval_time = NULL, ..., quiet

#' @rdname choose_metric
#' @export
check_metrics_arg <- function(mtr_set, wflow, call = rlang::caller_env()) {
check_metrics_arg <- function(mtr_set, wflow, ..., call = rlang::caller_env()) {
rlang::check_dots_empty(call = call)
mode <- extract_spec_parsnip(wflow)$mode

if (is.null(mtr_set)) {
@@ -308,7 +310,8 @@ check_metrics_arg <- function(mtr_set, wflow, call = rlang::caller_env()) {

#' @rdname choose_metric
#' @export
check_eval_time_arg <- function(eval_time, mtr_set, call = rlang::caller_env()) {
check_eval_time_arg <- function(eval_time, mtr_set, ..., call = rlang::caller_env()) {
rlang::check_dots_empty(call = call)
mtr_info <- tibble::as_tibble(mtr_set)

# Not a survival metric
16 changes: 8 additions & 8 deletions R/select_best.R
@@ -15,6 +15,12 @@
#' performance is within some acceptable limit.
#'
#' @param x The results of [tune_grid()] or [tune_bayes()].
#' @param ... For [select_by_one_std_err()] and [select_by_pct_loss()], this
#' argument is passed directly to [dplyr::arrange()] so that the user can sort
#' the models from *most simple to most complex*. That is, for a parameter `p`,
#' pass the unquoted expression `p` if smaller values of `p` indicate a simpler
#' model, or `desc(p)` if larger values indicate a simpler model. At
#' least one term is required for these two functions. See the examples below.
#' @param metric A character value for the metric that will be used to sort
#' the models. (See
#' \url{https://yardstick.tidymodels.org/articles/metric-types.html} for
@@ -24,12 +30,6 @@
#' @param n An integer for the number of top results/rows to return.
#' @param limit The limit of loss of performance that is acceptable (in percent
#' units). See details below.
#' @param ... For [select_by_one_std_err()] and [select_by_pct_loss()], this
#' argument is passed directly to [dplyr::arrange()] so that the user can sort
#' the models from *most simple to most complex*. That is, for a parameter `p`,
#' pass the unquoted expression `p` if smaller values of `p` indicate a simpler
#' model, or `desc(p)` if larger values indicate a simpler model. At
#' least one term is required for these two functions. See the examples below.
#' @param eval_time A single numeric time point where dynamic event time
#' metrics should be chosen (e.g., the time-dependent ROC curve, etc). The
#' values should be consistent with the values used to create `x`. The `NULL`
@@ -78,10 +78,10 @@ show_best.default <- function(x, ...) {
#' @export
#' @rdname show_best
show_best.tune_results <- function(x,
...,
metric = NULL,
eval_time = NULL,
n = 5,
...,
call = rlang::current_env()) {
rlang::check_dots_empty()

Expand Down Expand Up @@ -119,7 +119,7 @@ select_best.default <- function(x, ...) {

#' @export
#' @rdname show_best
select_best.tune_results <- function(x, metric = NULL, eval_time = NULL, ...) {
select_best.tune_results <- function(x, ..., metric = NULL, eval_time = NULL) {
rlang::check_dots_empty()

metric_info <- choose_metric(x, metric)
5 changes: 3 additions & 2 deletions R/tune_bayes.R
@@ -296,7 +296,7 @@ tune_bayes_workflow <- function(object,
maximize <- opt_metric$direction == "maximize"

eval_time <- check_eval_time_arg(eval_time, metrics, call = call)
opt_metric_time <- first_eval_time(metrics, opt_metric_name, eval_time, call = call)
opt_metric_time <- first_eval_time(metrics, metric = opt_metric_name, eval_time = eval_time, call = call)

if (is.null(param_info)) {
param_info <- hardhat::extract_parameter_set_dials(object)
@@ -547,7 +547,8 @@ check_iter <- function(iter, call) {
#' @rdname empty_ellipses
#' @param pset A `parameters` object.
#' @param as_matrix A logical for the return type.
encode_set <- function(x, pset, as_matrix = FALSE, ...) {
encode_set <- function(x, pset, ..., as_matrix = FALSE) {
rlang::check_dots_empty()
# change the numeric variables to the transformed scale (if any)
has_trans <- purrr::map_lgl(pset$object, ~ !is.null(.x$trans))
if (any(has_trans)) {