Workflows which use case weights calculated in a recipe from non-predictor columns error when attempting to predict(). #204

Open
mikemahoney218 opened this issue Jun 1, 2023 · 1 comment

@mikemahoney218
Member

The problem

Workflows which use dynamic case weights (calculated from non-predictor columns) error when attempting to predict().

This follows a string of issues, most recently tidymodels/hardhat#240, and is one of the causes of tidymodels/hardhat#242 (though there's another issue there, too). The basic idea is that for some forms of modeling, for instance species abundance modeling (and other forms of presence/background data), it makes sense to calculate case weights "on the fly" for each fold (here, as the ratio of presence observations to background observations in each analysis set). The suggestion was to do that by using step_mutate() so that the case weights would be updated during cross-validation. This seems to cause some issues with the resulting workflow, though.
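To make the "ratio of presence to background" idea concrete, here is a minimal sketch of the kind of expression that would go inside step_mutate(). The `presence` column and the toy data are hypothetical, and the weighting scheme (presence rows keep weight 1, background rows get the ratio) is an assumption for illustration only, not the scheme from the linked issues.

```r
library(dplyr)

# Hypothetical presence/background data: 1 = presence, 0 = background.
analysis_set <- data.frame(presence = c(1, 1, 0, 0, 0, 0, 0, 0))

weighted <- analysis_set |>
  mutate(
    # Ratio of presence to background observations in this analysis set
    ratio = sum(presence == 1) / sum(presence == 0),
    # Assumed scheme: presence rows keep weight 1, background rows are
    # down-weighted by the ratio
    cwts = hardhat::importance_weights(if_else(presence == 1, 1, ratio))
  )
```

Because the ratio depends on the rows present, the same expression evaluated inside a recipe step would be recomputed on each analysis set during resampling, which is the behavior being asked for.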

Reproducible example

set.seed(1107)

library(parsnip)
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step
library(workflows)

data(ames, package = "modeldata")
ames_model <- ames |> 
  mutate(cwts = hardhat::importance_weights(NA))

ames_recipe <- recipe(
  formula = Sale_Price ~ Longitude + Latitude, # not to be used in real life...
  data = ames_model
) |> 
  recipes::step_mutate(
    cwts = hardhat::importance_weights(abs(Sale_Price - mean(Sale_Price))),
    role = "case_weights"
  )

ames_wflow <- workflow(preprocessor = ames_recipe) |> 
  add_model(linear_reg()) |> 
  add_case_weights(cwts) |>
  fit(ames_model)

predict(ames_wflow, ames)
#> Error in `dplyr::mutate()`:
#> ℹ In argument: `cwts = hardhat::importance_weights(abs(Sale_Price -
#>   mean(Sale_Price)))`.
#> Caused by error:
#> ! object 'Sale_Price' not found
#> Backtrace:
#>      ▆
#>   1. ├─stats::predict(ames_wflow, ames)
#>   2. ├─workflows:::predict.workflow(ames_wflow, ames)
#>   3. │ └─workflows:::forge_predictors(new_data, workflow)
#>   4. │   ├─hardhat::forge(new_data, blueprint = mold$blueprint)
#>   5. │   └─hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint)
#>   6. │     ├─hardhat::run_forge(blueprint, new_data = new_data, outcomes = outcomes)
#>   7. │     └─hardhat:::run_forge.default_recipe_blueprint(...)
#>   8. │       └─hardhat:::forge_recipe_default_process(...)
#>   9. │         ├─recipes::bake(object = rec, new_data = new_data)
#>  10. │         └─recipes:::bake.recipe(object = rec, new_data = new_data)
#>  11. │           ├─recipes::bake(step, new_data = new_data)
#>  12. │           └─recipes:::bake.step_mutate(step, new_data = new_data)
#>  13. │             ├─dplyr::mutate(new_data, !!!object$inputs)
#>  14. │             └─dplyr:::mutate.data.frame(new_data, !!!object$inputs)
#>  15. │               └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#>  16. │                 ├─base::withCallingHandlers(...)
#>  17. │                 └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#>  18. │                   └─mask$eval_all_mutate(quo)
#>  19. │                     └─dplyr (local) eval()
#>  20. ├─hardhat::importance_weights(abs(Sale_Price - mean(Sale_Price)))
#>  21. │ └─hardhat:::vec_cast_named(x, to = double(), x_arg = "x")
#>  22. │   └─vctrs::vec_cast(x, to, ..., call = call)
#>  23. └─base::.handleSimpleError(...)
#>  24.   └─dplyr (local) h(simpleError(msg, call))
#>  25.     └─rlang::abort(message, class = error_class, parent = parent, call = error_call)

Created on 2023-06-01 with reprex v2.0.2

Speculation

It seems like the issue is that hardhat:::run_forge.default_recipe_blueprint() removes non-predictor variables (in this case, the outcome) prematurely, which then causes the recipe to not have the requisite variables for computation. Is there a workaround for this?
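The failure can be reproduced at the recipe level, without a workflow. The sketch below bakes the prepped recipe on data that lacks the outcome, which is roughly what happens at predict() time. It also tries `skip = TRUE`, the recipes argument documented for steps that cannot run on new data (for example, steps that touch the outcome). `make_recipe()` is a hypothetical helper for this example. I have not verified that `skip = TRUE` alone makes predict() succeed on the full workflow, since workflows and hardhat also handle the case weights column themselves, so treat it as something to try rather than a confirmed workaround.

```r
library(recipes)

data(ames, package = "modeldata")

# Hypothetical helper: same recipe as in the reprex, with `skip` toggled
make_recipe <- function(skip) {
  recipe(Sale_Price ~ Longitude + Latitude, data = ames) |>
    step_mutate(
      cwts = hardhat::importance_weights(abs(Sale_Price - mean(Sale_Price))),
      role = "case_weights",
      skip = skip
    ) |>
    prep(training = ames)
}

# New data as seen at prediction time: predictors only, no outcome
new_data <- ames[, c("Longitude", "Latitude")]

# skip = FALSE (the default): the step runs in bake() and needs Sale_Price,
# so the error is captured here instead of stopping the script
res_default <- tryCatch(
  bake(make_recipe(skip = FALSE), new_data = new_data),
  error = function(e) e
)

# skip = TRUE: the step is applied during prep() only and skipped by bake()
res_skip <- bake(make_recipe(skip = TRUE), new_data = new_data)
```

With `skip = TRUE` the baked new data simply has no `cwts` column, which seems reasonable given that case weights are only needed when fitting.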

@mikemahoney218
Member Author

This is the best workaround I can think of -- manually adding a case weights role to the original recipe, and having two different recipes (one for cross-validation, one for the final fit). This feels non-ideal, since I'd prefer only having one recipe floating around my environment if it were possible. I'm also slightly confused why the formula interface for recipe() didn't automatically add the role -- is that intended?

set.seed(1107)

library(parsnip)
library(recipes)
#> Loading required package: dplyr
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#> 
#> Attaching package: 'recipes'
#> The following object is masked from 'package:stats':
#> 
#>     step
library(workflows)

data(ames, package = "modeldata")
ames_model <- ames |> 
  mutate(
    cwts = hardhat::importance_weights(abs(Sale_Price - mean(Sale_Price)))
  )

# Should this have automatically added case weights?
recipe(
  formula = Sale_Price ~ Longitude + Latitude,
  data = ames_model
)
#> 
#> ── Recipe ──────────────────────────────────────────────────────────────────────
#> 
#> ── Inputs
#> Number of variables by role
#> outcome:   1
#> predictor: 2

ames_recipe <- recipe(
  ames_model,
  vars = c("Sale_Price", "Longitude", "Latitude", "cwts"),
  roles = c("outcome", "predictor", "predictor", "case_weight")
)

ames_recipe_cv <- ames_recipe |> 
  recipes::step_mutate(
    cwts = hardhat::importance_weights(abs(Sale_Price - mean(Sale_Price))),
    role = "case_weights"
  )

ames_wflow <- workflow() |> 
  add_model(linear_reg()) |> 
  add_case_weights(cwts)

ames_wflow |> 
  add_recipe(ames_recipe_cv) |> 
  tune::fit_resamples(rsample::vfold_cv(ames_model)) |> 
  tune::collect_metrics()
#> # A tibble: 2 × 6
#>   .metric .estimator      mean     n   std_err .config             
#>   <chr>   <chr>          <dbl> <int>     <dbl> <chr>               
#> 1 rmse    standard   90610.       10 899.      Preprocessor1_Model1
#> 2 rsq     standard       0.151    10   0.00748 Preprocessor1_Model1

ames_wflow |> 
  add_recipe(ames_recipe) |>
  fit(ames_model) |> 
  predict(ames)
#> # A tibble: 2,930 × 1
#>      .pred
#>      <dbl>
#>  1 245059.
#>  2 241882.
#>  3 240232.
#>  4 232777.
#>  5 294744.
#>  6 294358.
#>  7 293631.
#>  8 286673.
#>  9 286457.
#> 10 289630.
#> # ℹ 2,920 more rows

Created on 2023-06-01 with reprex v2.0.2
