Skip to content

Commit 9589f01

Browse files
committed
fix two bugs in brulee tunable methods
1 parent d4be8e4 commit 9589f01

File tree

3 files changed

+6
-29
lines changed

3 files changed

+6
-29
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.2.1.9002
3+
Version: 1.2.1.9003
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ S3method(tunable,logistic_reg)
132132
S3method(tunable,mars)
133133
S3method(tunable,mlp)
134134
S3method(tunable,model_spec)
135-
S3method(tunable,multinomial_reg)
135+
S3method(tunable,multinom_reg)
136136
S3method(tunable,rand_forest)
137137
S3method(tunable,survival_reg)
138138
S3method(tunable,svm_poly)

R/tunable.R

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
214214
brulee_args <-
215215
tibble::tibble(
216216
name = c('epochs', 'hidden_units', 'hidden_units_2', 'activation', 'activation_2',
217-
'penalty', 'dropout', 'learn_rate', 'momentum', 'batch_size',
217+
'penalty', 'mixture', 'dropout', 'learn_rate', 'momentum', 'batch_size',
218218
'class_weights', 'stop_iter', 'rate_schedule'),
219219
call_info = list(
220220
list(pkg = "dials", fun = "epochs", range = c(5L, 500L)),
@@ -223,6 +223,7 @@ brulee_args <-
223223
list(pkg = "dials", fun = "activation", values = tune_activations),
224224
list(pkg = "dials", fun = "activation_2", values = tune_activations),
225225
list(pkg = "dials", fun = "penalty"),
226+
list(pkg = "dials", fun = "mixture"),
226227
list(pkg = "dials", fun = "dropout"),
227228
list(pkg = "dials", fun = "learn_rate", range = c(-3, -1/5)),
228229
list(pkg = "dials", fun = "momentum", range = c(0.50, 0.95)),
@@ -253,34 +254,10 @@ tunable.linear_reg <- function(x, ...) {
253254
}
254255

255256
#' @export
256-
tunable.logistic_reg <- function(x, ...) {
257-
res <- NextMethod()
258-
if (x$engine == "glmnet") {
259-
res$call_info[res$name == "mixture"] <-
260-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
261-
} else if (x$engine == "brulee") {
262-
res <-
263-
brulee_args %>%
264-
dplyr::filter(name %in% tune_args(x)$name) %>%
265-
dplyr::full_join(res %>% dplyr::select(-call_info), by = "name")
266-
}
267-
res
268-
}
257+
tunable.logistic_reg <- tunable.linear_reg
269258

270259
#' @export
271-
tunable.multinomial_reg <- function(x, ...) {
272-
res <- NextMethod()
273-
if (x$engine == "glmnet") {
274-
res$call_info[res$name == "mixture"] <-
275-
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
276-
} else if (x$engine == "brulee") {
277-
res <-
278-
brulee_args %>%
279-
dplyr::filter(name %in% tune_args(x)$name) %>%
280-
dplyr::full_join(res %>% dplyr::select(-call_info), by = "name")
281-
}
282-
res
283-
}
260+
tunable.multinom_reg <- tunable.linear_reg
284261

285262
#' @export
286263
tunable.boost_tree <- function(x, ...) {

0 commit comments

Comments
 (0)