Skip to content

Commit

Permalink
fix two bugs in brulee tunable methods
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Sep 13, 2024
1 parent d4be8e4 commit 9589f01
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 29 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.2.1.9002
Version: 1.2.1.9003
Authors@R: c(
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
Expand Down
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ S3method(tunable,logistic_reg)
S3method(tunable,mars)
S3method(tunable,mlp)
S3method(tunable,model_spec)
S3method(tunable,multinomial_reg)
S3method(tunable,multinom_reg)
S3method(tunable,rand_forest)
S3method(tunable,survival_reg)
S3method(tunable,svm_poly)
Expand Down
31 changes: 4 additions & 27 deletions R/tunable.R
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
brulee_args <-
tibble::tibble(
name = c('epochs', 'hidden_units', 'hidden_units_2', 'activation', 'activation_2',
'penalty', 'dropout', 'learn_rate', 'momentum', 'batch_size',
'penalty', 'mixture', 'dropout', 'learn_rate', 'momentum', 'batch_size',
'class_weights', 'stop_iter', 'rate_schedule'),
call_info = list(
list(pkg = "dials", fun = "epochs", range = c(5L, 500L)),
Expand All @@ -223,6 +223,7 @@ brulee_args <-
list(pkg = "dials", fun = "activation", values = tune_activations),
list(pkg = "dials", fun = "activation_2", values = tune_activations),
list(pkg = "dials", fun = "penalty"),
list(pkg = "dials", fun = "mixture"),
list(pkg = "dials", fun = "dropout"),
list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/5)),
list(pkg = "dials", fun = "momentum", range = c(0.50, 0.95)),
Expand Down Expand Up @@ -253,34 +254,10 @@ tunable.linear_reg <- function(x, ...) {
}

#' @export
tunable.logistic_reg <- function(x, ...) {
res <- NextMethod()
if (x$engine == "glmnet") {
res$call_info[res$name == "mixture"] <-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
} else if (x$engine == "brulee") {
res <-
brulee_args %>%
dplyr::filter(name %in% tune_args(x)$name) %>%
dplyr::full_join(res %>% dplyr::select(-call_info), by = "name")
}
res
}
tunable.logistic_reg <- tunable.linear_reg

#' @export
tunable.multinomial_reg <- function(x, ...) {
res <- NextMethod()
if (x$engine == "glmnet") {
res$call_info[res$name == "mixture"] <-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
} else if (x$engine == "brulee") {
res <-
brulee_args %>%
dplyr::filter(name %in% tune_args(x)$name) %>%
dplyr::full_join(res %>% dplyr::select(-call_info), by = "name")
}
res
}
tunable.multinom_reg <- tunable.linear_reg

#' @export
tunable.boost_tree <- function(x, ...) {
Expand Down

0 comments on commit 9589f01

Please sign in to comment.