@@ -214,7 +214,7 @@ tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
214214brulee_args <-
215215 tibble :: tibble(
216216 name = c(' epochs' , ' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' ,
217- ' penalty' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
217+ ' penalty' , ' mixture ' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
218218 ' class_weights' , ' stop_iter' , ' rate_schedule' ),
219219 call_info = list (
220220 list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )),
@@ -223,6 +223,7 @@ brulee_args <-
223223 list (pkg = " dials" , fun = " activation" , values = tune_activations ),
224224 list (pkg = " dials" , fun = " activation_2" , values = tune_activations ),
225225 list (pkg = " dials" , fun = " penalty" ),
226+ list (pkg = " dials" , fun = " mixture" ),
226227 list (pkg = " dials" , fun = " dropout" ),
227228 list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
228229 list (pkg = " dials" , fun = " momentum" , range = c(0.50 , 0.95 )),
@@ -253,34 +254,10 @@ tunable.linear_reg <- function(x, ...) {
253254}
254255
255256# ' @export
256- tunable.logistic_reg <- function (x , ... ) {
257- res <- NextMethod()
258- if (x $ engine == " glmnet" ) {
259- res $ call_info [res $ name == " mixture" ] <-
260- list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
261- } else if (x $ engine == " brulee" ) {
262- res <-
263- brulee_args %> %
264- dplyr :: filter(name %in% tune_args(x )$ name ) %> %
265- dplyr :: full_join(res %> % dplyr :: select(- call_info ), by = " name" )
266- }
267- res
268- }
257+ tunable.logistic_reg <- tunable.linear_reg
269258
270259# ' @export
271- tunable.multinomial_reg <- function (x , ... ) {
272- res <- NextMethod()
273- if (x $ engine == " glmnet" ) {
274- res $ call_info [res $ name == " mixture" ] <-
275- list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
276- } else if (x $ engine == " brulee" ) {
277- res <-
278- brulee_args %> %
279- dplyr :: filter(name %in% tune_args(x )$ name ) %> %
280- dplyr :: full_join(res %> % dplyr :: select(- call_info ), by = " name" )
281- }
282- res
283- }
260+ tunable.multinom_reg <- tunable.linear_reg
284261
285262# ' @export
286263tunable.boost_tree <- function (x , ... ) {
0 commit comments