Skip to content

clarify case weight support in show_model_info() #1102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,7 @@ show_model_info <- function(model) {
) %>%
dplyr::select(engine, mode, has_wts)

engines %>%
engine_weight_info <- engines %>%
dplyr::left_join(weight_info, by = c("engine", "mode")) %>%
dplyr::mutate(
engine = paste0(engine, has_wts),
Expand All @@ -1005,9 +1005,15 @@ show_model_info <- function(model) {
lab = paste0(" ", mode, engine, "\n")
) %>%
dplyr::ungroup() %>%
dplyr::pull(lab) %>%
cat(sep = "")
cat("\n", cli::symbol$sup_1, "The model can use case weights.\n\n", sep = "")
dplyr::pull(lab)

cat(engine_weight_info, sep = "")

if (!all(weight_info$has_wts == "")) {
cat("\n", cli::symbol$sup_1, "The model can use case weights.", sep = "")
}

cat("\n\n")
} else {
cat(" no registered engines.\n\n")
}
Expand Down
98 changes: 98 additions & 0 deletions tests/testthat/_snaps/registration.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,101 @@
Error in `check_mode_for_new_engine()`:
! "regression" is not a known mode for model `sponge()`.

# showing model info

Code
show_model_info("rand_forest")
Output
Information for `rand_forest`
modes: unknown, classification, regression, censored regression

engines:
classification: randomForest, ranger1, spark
regression: randomForest, ranger1, spark

1The model can use case weights.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that local_reproducible_output() translates ¹ to 1.


arguments:
ranger:
mtry --> mtry
trees --> num.trees
min_n --> min.node.size
randomForest:
mtry --> mtry
trees --> ntree
min_n --> nodesize
spark:
mtry --> feature_subset_strategy
trees --> num_trees
min_n --> min_instances_per_node

fit modules:
engine mode
ranger classification
ranger regression
randomForest classification
randomForest regression
spark classification
spark regression

prediction modules:
mode engine methods
classification randomForest class, prob, raw
classification ranger class, conf_int, prob, raw
classification spark class, prob
regression randomForest numeric, raw
regression ranger conf_int, numeric, raw
regression spark numeric


---

Code
show_model_info("mlp")
Output
Information for `mlp`
modes: unknown, classification, regression

engines:
classification: brulee, keras, nnet
regression: brulee, keras, nnet


arguments:
keras:
hidden_units --> hidden_units
penalty --> penalty
dropout --> dropout
epochs --> epochs
activation --> activation
nnet:
hidden_units --> size
penalty --> decay
epochs --> maxit
brulee:
hidden_units --> hidden_units
penalty --> penalty
epochs --> epochs
dropout --> dropout
learn_rate --> learn_rate
activation --> activation

fit modules:
engine mode
keras regression
keras classification
nnet regression
nnet classification
brulee regression
brulee classification

prediction modules:
mode engine methods
classification brulee class, prob
classification keras class, prob, raw
classification nnet class, prob, raw
regression brulee numeric
regression keras numeric, raw
regression nnet numeric, raw


21 changes: 5 additions & 16 deletions tests/testthat/test_registration.R
Original file line number Diff line number Diff line change
Expand Up @@ -496,21 +496,10 @@ test_that('adding a new predict method', {


test_that('showing model info', {
expect_output(
show_model_info("rand_forest"),
"Information for `rand_forest`"
)
expect_output(
show_model_info("rand_forest"),
"trees --> ntree"
)
expect_output(
show_model_info("rand_forest"),
"fit modules:"
)
expect_output(
show_model_info("rand_forest"),
"prediction modules:"
)
expect_snapshot(show_model_info("rand_forest"))

# ensure that we don't mention case weight support when the
# notation would be ambiguous (#1000)
expect_snapshot(show_model_info("mlp"))
})