Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resample calibration post-processors with an internal split #894

Merged
merged 19 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Depends:
R (>= 4.0)
Imports:
cli (>= 3.3.0),
container,
dials (>= 1.0.0),
doFuture (>= 1.0.0),
dplyr (>= 1.1.0),
Expand Down Expand Up @@ -54,6 +55,10 @@ Suggests:
testthat (>= 3.0.0),
xgboost,
xml2
Remotes:
tidymodels/container#12,
tidymodels/workflows#225,
tidymodels/hardhat
Config/Needs/website: pkgdown, tidymodels, kknn, doParallel, doFuture,
tidyverse/tidytemplate
Config/testthat/edition: 3
Expand Down
78 changes: 71 additions & 7 deletions R/grid_code_paths.R
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ tune_grid_loop_iter <- function(split,
seed,
metrics_info = metrics_info(metrics),
params) {
# `split` may be overwritten later on to create an "internal" split for
# post-processing. however, we want the original split to persist so we can
# use it (particularly `labels(split_orig)`) in logging
split_orig <- split

load_pkgs(workflow)
.load_namespace(control$pkgs)
Expand Down Expand Up @@ -373,6 +377,25 @@ tune_grid_loop_iter <- function(split,

training <- rsample::analysis(split)
Copy link
Member

@topepo topepo Apr 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should change this name just on principle


if (should_internal_split(workflow)) {
# if the workflow has a postprocessor that needs training (i.e. calibration),
# further split the analysis data into an "internal" analysis and
# assessment set.
# * the preprocessor and model (excluding the post-processor) are fitted
# on `analysis(split_post)`, the internal analysis set
# * that model generates predictions on `assessment(split_post)`, the
# internal assessment set
# * the post-processor is trained on the predictions generated from the
# internal assessment set
# * the model (including the post-processor) generates predictions on the
# assessment set (not internal, i.e. `assessment(split)`) and those
# predictions are assessed with performance metrics
split <- rsample::initial_split(training)

This comment was marked as outdated.

# todo: this should have a better name (analysis?) -- needs to be
# `training` right now to align with the `training` above
training <- rsample::analysis(split)
simonpcouch marked this conversation as resolved.
Show resolved Hide resolved
}

# ----------------------------------------------------------------------------
# Preprocessor loop

Expand Down Expand Up @@ -400,7 +423,7 @@ tune_grid_loop_iter <- function(split,
workflow <- .catch_and_log(
.expr = .fit_pre(workflow, training),
control,
split,
split_orig,
iter_msg_preprocessor,
notes = out_notes
)
Expand Down Expand Up @@ -435,7 +458,7 @@ tune_grid_loop_iter <- function(split,
workflow <- .catch_and_log_fit(
.expr = .fit_model(workflow, control_workflow),
control,
split,
split_orig,
iter_msg_model,
notes = out_notes
)
Expand All @@ -460,15 +483,19 @@ tune_grid_loop_iter <- function(split,
iter_grid_model
)

# to-do: this currently doesn't include the trained post-processor.
# we could either `if (!should_internal_split())` here and the opposite
# condition later OR just extract later than we used to (possibly meaning
# that failing to predict means no extracts).
elt_extract <- .catch_and_log(
extract_details(workflow, control$extract),
control,
split,
split_orig,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split, .config = iter_config)
elt_extract <- make_extracts(elt_extract, iter_grid, split_orig, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)

iter_msg_predictions <- paste(iter_msg_model, "(predictions)")
Expand All @@ -477,7 +504,7 @@ tune_grid_loop_iter <- function(split,
predict_model(split, workflow, iter_grid, metrics, iter_submodels,
metrics_info = metrics_info, eval_time = eval_time),
control,
split,
split_orig,
iter_msg_predictions,
bad_only = TRUE,
notes = out_notes
Expand All @@ -488,14 +515,51 @@ tune_grid_loop_iter <- function(split,
next
}

if (should_internal_split(workflow)) {
# note that, since we're training a postprocessor, `iter_predictions`
# are the predictions from the internal assessment set rather than the
# assessment set (i.e. `assessment(split_orig)`)

# train the post-processor on the predictions generated from the model
# on the internal assessment set
# todo: this is the same assessment set that `predict_model` makes.
# we're ad-hoc `augment()`ing here, but would be nice to just have
# those predictors
# todo: needs a `.catch_and_log`
# todo: .fit_post currently takes in `assessment(split)` rather than
# a set of predictions, meaning that we predict on `assessment(split)`
# twice :(
internal_assessment <- assessment(split)
workflow_with_post <-
.fit_post(workflow, dplyr::bind_cols(assessment(split)))

workflow_with_post <- .fit_finalize(workflow_with_post)

# generate predictions on the assessment set (not internal,
# i.e. `assessment(split_orig)`) from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(split_orig, workflow_with_post, iter_grid, metrics,
iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_orig,
paste(iter_msg_model, "(predictions with post-processor)"),
bad_only = TRUE,
notes = out_notes
)

# now, assess those predictions with performance metrics
}

out_metrics <- append_metrics(
collection = out_metrics,
predictions = iter_predictions,
metrics = metrics,
param_names = param_names,
outcome_name = outcome_names,
event_level = event_level,
split = split,
split = split_orig,
.config = iter_config,
metrics_info = metrics_info
)
Expand All @@ -505,7 +569,7 @@ tune_grid_loop_iter <- function(split,
out_predictions <- append_predictions(
collection = out_predictions,
predictions = iter_predictions,
split = split,
split = split_orig,
control = control,
.config = iter_config_metrics
)
Expand Down
23 changes: 23 additions & 0 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,25 @@ compute_config_ids <- function(data, id_preprocessor) {
out
}

should_internal_split <- function(workflow) {
  # An "internal" split of the analysis set is only needed when the workflow
  # carries a postprocessor AND that postprocessor must itself be trained
  # (e.g. a calibration step, which needs held-out predictions to estimate).
  if (!has_postprocessor(workflow)) {
    return(FALSE)
  }
  postprocessor_requires_training(workflow)
}

postprocessor_requires_training <- function(workflow) {
  # Returns TRUE when the workflow's postprocessor contains at least one
  # calibration operation. Calibrators must be estimated on predictions from
  # data the model was not fitted on, so they require a training pass (and
  # thus an internal split) of their own.
  # todo: `extract_postprocessor(workflow)` would fail here
  container <- workflow$post$actions$container$container

  # base `inherits()` accepts a vector of classes and returns TRUE when the
  # object inherits from any of them, so `rlang::inherits_any()` is not needed
  operations_are_calibration <-
    vapply(
      container$operations,
      inherits,
      logical(1),
      what = c("numeric_calibration", "probability_calibration")
    )

  # `any(logical(0))` is FALSE, so an empty/NULL operations list is handled
  any(operations_are_calibration)
}

# ------------------------------------------------------------------------------

has_preprocessor <- function(workflow) {
Expand All @@ -628,6 +647,10 @@ has_preprocessor_variables <- function(workflow) {
"variables" %in% names(workflow$pre$actions)
}

has_postprocessor <- function(workflow) {
  # TRUE when a "container" action has been registered in the workflow's
  # post stage, i.e. the workflow has a postprocessor attached.
  post_action_names <- names(workflow$post$actions)
  is.element("container", post_action_names)
}

has_case_weights <- function(workflow) {
"case_weights" %in% names(workflow$pre$actions)
}
Expand Down
Loading