@@ -214,7 +214,7 @@ tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
214
214
brulee_args <-
215
215
tibble :: tibble(
216
216
name = c(' epochs' , ' hidden_units' , ' hidden_units_2' , ' activation' , ' activation_2' ,
217
- ' penalty' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
217
+ ' penalty' , ' mixture ' , ' dropout' , ' learn_rate' , ' momentum' , ' batch_size' ,
218
218
' class_weights' , ' stop_iter' , ' rate_schedule' ),
219
219
call_info = list (
220
220
list (pkg = " dials" , fun = " epochs" , range = c(5L , 500L )),
@@ -223,6 +223,7 @@ brulee_args <-
223
223
list (pkg = " dials" , fun = " activation" , values = tune_activations ),
224
224
list (pkg = " dials" , fun = " activation_2" , values = tune_activations ),
225
225
list (pkg = " dials" , fun = " penalty" ),
226
+ list (pkg = " dials" , fun = " mixture" ),
226
227
list (pkg = " dials" , fun = " dropout" ),
227
228
list (pkg = " dials" , fun = " learn_rate" , range = c(- 3 , - 1 / 5 )),
228
229
list (pkg = " dials" , fun = " momentum" , range = c(0.50 , 0.95 )),
@@ -253,34 +254,10 @@ tunable.linear_reg <- function(x, ...) {
253
254
}
254
255
255
256
# ' @export
256
- tunable.logistic_reg <- function (x , ... ) {
257
- res <- NextMethod()
258
- if (x $ engine == " glmnet" ) {
259
- res $ call_info [res $ name == " mixture" ] <-
260
- list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
261
- } else if (x $ engine == " brulee" ) {
262
- res <-
263
- brulee_args %> %
264
- dplyr :: filter(name %in% tune_args(x )$ name ) %> %
265
- dplyr :: full_join(res %> % dplyr :: select(- call_info ), by = " name" )
266
- }
267
- res
268
- }
257
+ tunable.logistic_reg <- tunable.linear_reg
269
258
270
259
# ' @export
271
- tunable.multinomial_reg <- function (x , ... ) {
272
- res <- NextMethod()
273
- if (x $ engine == " glmnet" ) {
274
- res $ call_info [res $ name == " mixture" ] <-
275
- list (list (pkg = " dials" , fun = " mixture" , range = c(0.05 , 1.00 )))
276
- } else if (x $ engine == " brulee" ) {
277
- res <-
278
- brulee_args %> %
279
- dplyr :: filter(name %in% tune_args(x )$ name ) %> %
280
- dplyr :: full_join(res %> % dplyr :: select(- call_info ), by = " name" )
281
- }
282
- res
283
- }
260
+ tunable.multinom_reg <- tunable.linear_reg
284
261
285
262
# ' @export
286
263
tunable.boost_tree <- function (x , ... ) {
0 commit comments