Skip to content

Commit 897c927

Browse files
committed
fix some parameter mappings between parsnip and the underlying model function for #238
1 parent 42d5ba5 commit 897c927

7 files changed

+33
-24
lines changed

R/boost_tree_data.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ set_model_arg(
309309
model = "boost_tree",
310310
eng = "spark",
311311
parsnip = "min_info_gain",
312-
original = "gamma",
312+
original = "loss_reduction",
313313
func = list(pkg = "dials", fun = "loss_reduction"),
314314
has_submodel = FALSE
315315
)

R/linear_reg_data.R

+9
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,15 @@ set_model_engine("linear_reg", "regression", "keras")
342342
set_dependency("linear_reg", "keras", "keras")
343343
set_dependency("linear_reg", "keras", "magrittr")
344344

345+
set_model_arg(
346+
model = "linear_reg",
347+
eng = "keras",
348+
parsnip = "penalty",
349+
original = "penalty",
350+
func = list(pkg = "dials", fun = "penalty"),
351+
has_submodel = FALSE
352+
)
353+
345354
set_fit(
346355
model = "linear_reg",
347356
eng = "keras",

R/logistic_reg_data.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,9 @@ set_dependency("logistic_reg", "keras", "magrittr")
288288
set_model_arg(
289289
model = "logistic_reg",
290290
eng = "keras",
291-
parsnip = "decay",
292-
original = "decay",
293-
func = list(pkg = "dials", fun = "weight_decay"),
291+
parsnip = "penalty",
292+
original = "penalty",
293+
func = list(pkg = "dials", fun = "penalty"),
294294
has_submodel = FALSE
295295
)
296296

R/mlp.R

+10-10
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,12 @@ class2ind <- function (x, drop2nd = FALSE) {
265265
#' @param x A data frame or matrix of predictors
266266
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
267267
#' @param hidden_units An integer for the number of hidden units.
268-
#' @param decay A non-negative real number for the amount of weight decay. Either
268+
#' @param penalty A non-negative real number for the amount of weight decay. Either
269269
#' this parameter _or_ `dropout` can specified.
270270
#' @param dropout The proportion of parameters to set to zero. Either
271-
#' this parameter _or_ `decay` can specified.
271+
#' this parameter _or_ `penalty` can specified.
272272
#' @param epochs An integer for the number of passes through the data.
273-
#' @param act A character string for the type of activation function between layers.
273+
#' @param activation A character string for the type of activation function between layers.
274274
#' @param seeds A vector of three positive integers to control randomness of the
275275
#' calculations.
276276
#' @param ... Currently ignored.
@@ -279,11 +279,11 @@ class2ind <- function (x, drop2nd = FALSE) {
279279
#' @export
280280
keras_mlp <-
281281
function(x, y,
282-
hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax",
282+
hidden_units = 5, penalty = 0, dropout = 0, epochs = 20, activation = "softmax",
283283
seeds = sample.int(10^5, size = 3),
284284
...) {
285285

286-
if (decay > 0 & dropout > 0) {
286+
if (penalty > 0 & dropout > 0) {
287287
stop("Please use either dropoput or weight decay.", call. = FALSE)
288288
}
289289
if (!is.matrix(x)) {
@@ -307,20 +307,20 @@ keras_mlp <-
307307

308308
model <- keras::keras_model_sequential()
309309

310-
if (decay > 0) {
310+
if (penalty > 0) {
311311
model %>%
312312
keras::layer_dense(
313313
units = hidden_units,
314-
activation = act,
314+
activation = activation,
315315
input_shape = ncol(x),
316-
kernel_regularizer = keras::regularizer_l2(decay),
316+
kernel_regularizer = keras::regularizer_l2(penalty),
317317
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
318318
)
319319
} else {
320320
model %>%
321321
keras::layer_dense(
322322
units = hidden_units,
323-
activation = act,
323+
activation = activation,
324324
input_shape = ncol(x),
325325
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
326326
)
@@ -330,7 +330,7 @@ keras_mlp <-
330330
model %>%
331331
keras::layer_dense(
332332
units = hidden_units,
333-
activation = act,
333+
activation = activation,
334334
input_shape = ncol(x),
335335
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
336336
) %>%

R/mlp_data.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ set_model_arg(
2424
eng = "keras",
2525
parsnip = "penalty",
2626
original = "penalty",
27-
func = list(pkg = "dials", fun = "weight_decay"),
27+
func = list(pkg = "dials", fun = "penalty"),
2828
has_submodel = FALSE
2929
)
3030
set_model_arg(
@@ -188,7 +188,7 @@ set_model_arg(
188188
eng = "nnet",
189189
parsnip = "penalty",
190190
original = "decay",
191-
func = list(pkg = "dials", fun = "weight_decay"),
191+
func = list(pkg = "dials", fun = "penalty"),
192192
has_submodel = FALSE
193193
)
194194
set_model_arg(

R/multinom_reg_data.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ set_dependency("multinom_reg", "keras", "magrittr")
172172
set_model_arg(
173173
model = "multinom_reg",
174174
eng = "keras",
175-
parsnip = "decay",
176-
original = "decay",
177-
func = list(pkg = "dials", fun = "weight_decay"),
175+
parsnip = "penalty",
176+
original = "penalty",
177+
func = list(pkg = "dials", fun = "penalty"),
178178
has_submodel = FALSE
179179
)
180180

man/keras_mlp.Rd

+5-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)