Skip to content

Commit

Permalink
rephrase "inner split"
Browse files Browse the repository at this point in the history
* `.should_inner_split()` -> `.workflow_includes_calibration()`
* refactor conditional in `fit.workflow()` into two
  • Loading branch information
simonpcouch committed Sep 30, 2024
1 parent 0a38a86 commit b28a6c4
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 22 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export(.fit_finalize)
export(.fit_model)
export(.fit_post)
export(.fit_pre)
export(.should_inner_split)
export(.workflow_includes_calibration)
export(add_case_weights)
export(add_formula)
export(add_model)
Expand Down
16 changes: 10 additions & 6 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ fit.workflow <- function(object, data, ..., calibration = NULL, control = contro
workflow <- object
workflow <- .fit_pre(workflow, data)
workflow <- .fit_model(workflow, control)

if (!.workflow_includes_calibration(workflow)) {
# in this case, training the tailor on `data` will not leak data (#262)
calibration <- data
}
if (has_postprocessor(workflow)) {
# if (is.null(calibration)), then the tailor doesn't have a calibrator
# and training the tailor on `data` will not leak data
workflow <- .fit_post(workflow, calibration %||% data)
workflow <- .fit_post(workflow, calibration)
}

workflow <- .fit_finalize(workflow)

workflow
Expand All @@ -84,7 +88,7 @@ fit.workflow <- function(object, data, ..., calibration = NULL, control = contro
#' @export
#' @rdname workflows-internals
#' @keywords internal
.should_inner_split <- function(workflow) {
.workflow_includes_calibration <- function(workflow) {
has_postprocessor(workflow) &&
tailor::tailor_requires_fit(
extract_postprocessor(workflow, estimated = FALSE)
Expand Down Expand Up @@ -227,15 +231,15 @@ validate_has_model <- function(x, ..., call = caller_env()) {
}

validate_has_calibration <- function(x, calibration, call = caller_env()) {
if (.should_inner_split(x) && is.null(calibration)) {
if (.workflow_includes_calibration(x) && is.null(calibration)) {
cli::cli_abort(
"The workflow requires a {.arg calibration} set to train but none
was supplied.",
call = call
)
}

if (!.should_inner_split(x) && !is.null(calibration)) {
if (!.workflow_includes_calibration(x) && !is.null(calibration)) {
cli::cli_warn(
"The workflow does not require a {.arg calibration} set to train
but one was supplied.",
Expand Down
2 changes: 1 addition & 1 deletion R/post-action-tailor.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#' which then form the training data for the postprocessor.
#'
#' When fitting a workflow with a postprocessor that requires training
#' (i.e. one that returns `TRUE` in `.should_inner_split(workflow)`), users
#' (i.e. one that returns `TRUE` in `.workflow_includes_calibration(workflow)`), users
#' must pass two data arguments--the usual `fit.workflow(data)` will be used
#' to train the preprocessor and model while `fit.workflow(calibration)` will
#' be used to train the postprocessor.
Expand Down
2 changes: 1 addition & 1 deletion man/add_tailor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/workflows-internals.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 10 additions & 10 deletions tests/testthat/test-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,39 +197,39 @@ test_that("`.fit_pre()` doesn't modify user supplied recipe blueprint", {

# ------------------------------------------------------------------------------
# .fit_post()
test_that(".should_inner_split works", {
test_that(".workflow_includes_calibration works", {
skip_if_not_installed("tailor")

expect_false(.should_inner_split(workflow()))
expect_false(.should_inner_split(workflow() %>% add_model(parsnip::linear_reg())))
expect_false(.should_inner_split(workflow() %>% add_formula(mpg ~ .)))
expect_false(.should_inner_split(
expect_false(.workflow_includes_calibration(workflow()))
expect_false(.workflow_includes_calibration(workflow() %>% add_model(parsnip::linear_reg())))
expect_false(.workflow_includes_calibration(workflow() %>% add_formula(mpg ~ .)))
expect_false(.workflow_includes_calibration(
workflow() %>%
add_formula(mpg ~ .) %>%
add_model(parsnip::linear_reg())
))
expect_false(.should_inner_split(
expect_false(.workflow_includes_calibration(
workflow() %>%
add_tailor(tailor::tailor())
))
expect_false(.should_inner_split(
expect_false(.workflow_includes_calibration(
workflow() %>%
add_tailor(tailor::tailor() %>% tailor::adjust_probability_threshold(.4))
))

expect_true(.should_inner_split(
expect_true(.workflow_includes_calibration(
workflow() %>%
add_tailor(tailor::tailor() %>% tailor::adjust_numeric_calibration())
))
expect_true(.should_inner_split(
expect_true(.workflow_includes_calibration(
workflow() %>%
add_tailor(
tailor::tailor() %>%
tailor::adjust_numeric_calibration() %>%
tailor::adjust_numeric_range(lower_limit = 1)
)
))
expect_true(.should_inner_split(
expect_true(.workflow_includes_calibration(
workflow() %>%
add_formula(mpg ~ .) %>%
add_model(parsnip::linear_reg()) %>%
Expand Down

0 comments on commit b28a6c4

Please sign in to comment.