1- 
21#  Unit tests are in extratests
32#  nocov start
43
54# ' @export
65tunable.model_spec  <-  function (x , ... ) {
7- 
86  mod_env  <-  get_model_env()
97
108  if  (is.null(x $ engine )) {
@@ -13,9 +11,14 @@ tunable.model_spec <- function(x, ...) {
1311
1412  arg_name  <-  paste0(mod_type(x ), " _args" 
1513  if  (! (any(arg_name  ==  names(mod_env )))) {
16-     stop(" The `parsnip` model database doesn't know about the arguments for " 
17-          " model `" x ), " `. Was it registered?" 
18-          sep  =  " " call.  =  FALSE )
14+     stop(
15+       " The `parsnip` model database doesn't know about the arguments for " 
16+       " model `" 
17+       mod_type(x ),
18+       " `. Was it registered?" 
19+       sep  =  " " 
20+       call.  =  FALSE 
21+     )
1922  }
2023
2124  arg_vals  <-  mod_env [[arg_name ]]
@@ -28,7 +31,10 @@ tunable.model_spec <- function(x, ...) {
2831
2932  extra_args_tbl  <- 
3033    tibble :: new_tibble(
31-       list (name  =  extra_args , call_info  =  vector(" list" vctrs :: vec_size(extra_args ))),
34+       list (
35+         name  =  extra_args ,
36+         call_info  =  vector(" list" vctrs :: vec_size(extra_args ))
37+       ),
3238      nrow  =  vctrs :: vec_size(extra_args )
3339    )
3440
@@ -57,7 +63,7 @@ add_engine_parameters <- function(pset, engines) {
5763  is_engine_param  <-  pset $ name  %in%  engines $ name 
5864  if  (any(is_engine_param )) {
5965    engine_names  <-  pset $ name [is_engine_param ]
60-     pset  <-  pset [! is_engine_param ,]
66+     pset  <-  pset [! is_engine_param ,  ]
6167    pset  <- 
6268      dplyr :: bind_rows(pset , engines  | >  dplyr :: filter(name  %in%  engines $ name ))
6369  }
@@ -213,9 +219,22 @@ tune_sched <- c("none", "decay_time", "decay_expo", "cyclic", "step")
213219
214220brulee_mlp_args  <- 
215221  tibble :: tibble(
216-     name  =  c(' epochs' ' hidden_units' ' hidden_units_2' ' activation' ' activation_2' 
217-              ' penalty' ' mixture' ' dropout' ' learn_rate' ' momentum' ' batch_size' 
218-              ' class_weights' ' stop_iter' ' rate_schedule' 
222+     name  =  c(
223+       ' epochs' 
224+       ' hidden_units' 
225+       ' hidden_units_2' 
226+       ' activation' 
227+       ' activation_2' 
228+       ' penalty' 
229+       ' mixture' 
230+       ' dropout' 
231+       ' learn_rate' 
232+       ' momentum' 
233+       ' batch_size' 
234+       ' class_weights' 
235+       ' stop_iter' 
236+       ' rate_schedule' 
237+     ),
219238    call_info  =  list (
220239      list (pkg  =  " dials" fun  =  " epochs" range  =  c(5L , 500L )),
221240      list (pkg  =  " dials" fun  =  " hidden_units" range  =  c(2L , 50L )),
@@ -225,9 +244,9 @@ brulee_mlp_args <-
225244      list (pkg  =  " dials" fun  =  " penalty" 
226245      list (pkg  =  " dials" fun  =  " mixture" 
227246      list (pkg  =  " dials" fun  =  " dropout" 
228-       list (pkg  =  " dials" fun  =  " learn_rate" range  =  c(- 3 , - 1 / 5 )),
229-       list (pkg  =  " dials" fun  =  " momentum" range  =  c(0.50  , 0.95  )),
230-       list (pkg  =  " dials" fun  =  " batch_size" 
247+       list (pkg  =  " dials" fun  =  " learn_rate" range  =  c(- 3 , - 1   /   5 )),
248+       list (pkg  =  " dials" fun  =  " momentum" range  =  c(0.00  , 0.99  )),
249+       list (pkg  =  " dials" fun  =  " batch_size" ,  range   =  c( 3L ,  8L ) ),
231250      list (pkg  =  " dials" fun  =  " class_weights" 
232251      list (pkg  =  " dials" fun  =  " stop_iter" 
233252      list (pkg  =  " dials" fun  =  " rate_schedule" values  =  tune_sched )
@@ -237,8 +256,13 @@ brulee_mlp_args <-
237256
238257brulee_mlp_only_args  <- 
239258  tibble :: tibble(
240-     name  = 
241-       c(' hidden_units' ' hidden_units_2' ' activation' ' activation_2' ' dropout' 
259+     name  =  c(
260+       ' hidden_units' 
261+       ' hidden_units_2' 
262+       ' activation' 
263+       ' activation_2' 
264+       ' dropout' 
265+     )
242266  )
243267
244268#  ------------------------------------------------------------------------------
@@ -256,7 +280,11 @@ tunable.linear_reg <- function(x, ...) {
256280      dplyr :: filter(name  !=  " class_weights" | > 
257281      dplyr :: mutate(
258282        component  =  " linear_reg" 
259-         component_id  =  ifelse(name  %in%  names(formals(" linear_reg" " main" " engine" 
283+         component_id  =  ifelse(
284+           name  %in%  names(formals(" linear_reg" 
285+           " main" 
286+           " engine" 
287+         )
260288      ) | > 
261289      dplyr :: select(name , call_info , source , component , component_id )
262290  }
@@ -277,7 +305,11 @@ tunable.logistic_reg <- function(x, ...) {
277305      dplyr :: anti_join(brulee_mlp_only_args , by  =  " name" | > 
278306      dplyr :: mutate(
279307        component  =  " logistic_reg" 
280-         component_id  =  ifelse(name  %in%  names(formals(" logistic_reg" " main" " engine" 
308+         component_id  =  ifelse(
309+           name  %in%  names(formals(" logistic_reg" 
310+           " main" 
311+           " engine" 
312+         )
281313      ) | > 
282314      dplyr :: select(name , call_info , source , component , component_id )
283315  }
@@ -296,7 +328,11 @@ tunable.multinom_reg <- function(x, ...) {
296328      dplyr :: anti_join(brulee_mlp_only_args , by  =  " name" | > 
297329      dplyr :: mutate(
298330        component  =  " multinom_reg" 
299-         component_id  =  ifelse(name  %in%  names(formals(" multinom_reg" " main" " engine" 
331+         component_id  =  ifelse(
332+           name  %in%  names(formals(" multinom_reg" 
333+           " main" 
334+           " engine" 
335+         )
300336      ) | > 
301337      dplyr :: select(name , call_info , source , component , component_id )
302338  }
@@ -311,7 +347,7 @@ tunable.boost_tree <- function(x, ...) {
311347    res $ call_info [res $ name  ==  " sample_size" <- 
312348      list (list (pkg  =  " dials" fun  =  " sample_prop" 
313349    res $ call_info [res $ name  ==  " learn_rate" <- 
314-       list (list (pkg  =  " dials" fun  =  " learn_rate" range  =  c(- 3 , - 1 / 2 )))
350+       list (list (pkg  =  " dials" fun  =  " learn_rate" range  =  c(- 3 , - 1   /   2 )))
315351  } else  if  (x $ engine  ==  " C5.0" 
316352    res  <-  add_engine_parameters(res , c5_boost_engine_args )
317353    res $ call_info [res $ name  ==  " trees" <- 
@@ -357,9 +393,11 @@ tunable.decision_tree <- function(x, ...) {
357393    res  <-  add_engine_parameters(res , c5_tree_engine_args )
358394  } else  if  (x $ engine  ==  " partykit" 
359395    res  <- 
360-       add_engine_parameters(res ,
361-                             partykit_engine_args  | > 
362-                               dplyr :: mutate(component  =  " decision_tree" 
396+       add_engine_parameters(
397+         res ,
398+         partykit_engine_args  | > 
399+           dplyr :: mutate(component  =  " decision_tree" 
400+       )
363401  }
364402  res 
365403}
@@ -386,7 +424,7 @@ tunable.mlp <- function(x, ...) {
386424      ) | > 
387425      dplyr :: select(name , call_info , source , component , component_id )
388426    if  (x $ engine  ==  " brulee" 
389-       res  <-  res [! grepl(" _2" res $ name ),]
427+       res  <-  res [! grepl(" _2" res $ name ),  ]
390428    }
391429  }
392430  res 
@@ -402,4 +440,3 @@ tunable.survival_reg <- function(x, ...) {
402440}
403441
404442#  nocov end
405- 
0 commit comments