Fix param names #245

Merged: 2 commits, Dec 2, 2019

NEWS.md (4 additions & 0 deletions)
@@ -12,6 +12,10 @@

* `nnet` was added as an engine to `multinom_reg()` [#209](https://github.com/tidymodels/parsnip/issues/209)

+## Breaking Changes
+
+* There were some mis-mapped parameters (going between `parsnip` and the underlying model function) for `spark` boosted trees and some `keras` models. See [897c927](https://github.com/tidymodels/parsnip/commit/897c92719332caf7344e7c9c8895ac673517d2c8).


# parsnip 0.0.4

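The breaking-change entry above concerns how `parsnip` argument names are mapped to the arguments of the underlying engine function. As a minimal sketch (not part of this diff), the registered mappings can be inspected from `parsnip`'s model environment; this assumes the arguments for `boost_tree()` are stored under the `"boost_tree_args"` key, following the package's `<model>_args` naming convention:

```r
library(parsnip)

# Each registered argument has a parsnip-side name, an engine-side ("original")
# name, and the dials function used when tuning. Filtering to the spark engine
# shows the mapping touched by this PR.
boost_args <- get_from_env("boost_tree_args")
boost_args[boost_args$engine == "spark", c("engine", "parsnip", "original")]
```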
R/boost_tree_data.R (1 addition & 1 deletion)
@@ -309,7 +309,7 @@ set_model_arg(
model = "boost_tree",
eng = "spark",
parsnip = "min_info_gain",
original = "gamma",
original = "loss_reduction",
func = list(pkg = "dials", fun = "loss_reduction"),
has_submodel = FALSE
)
R/linear_reg_data.R (9 additions & 0 deletions)
@@ -342,6 +342,15 @@ set_model_engine("linear_reg", "regression", "keras")
set_dependency("linear_reg", "keras", "keras")
set_dependency("linear_reg", "keras", "magrittr")

+set_model_arg(
+model = "linear_reg",
+eng = "keras",
+parsnip = "penalty",
+original = "penalty",
+func = list(pkg = "dials", fun = "penalty"),
+has_submodel = FALSE
+)

set_fit(
model = "linear_reg",
eng = "keras",
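With the registration added above, `penalty` becomes a main argument for the keras engine of `linear_reg()`. A small usage sketch (assuming `parsnip` is installed; `translate()` only builds the call template, so nothing is fit and no keras session is needed):

```r
library(parsnip)

# Specify a regularized linear model for the keras engine and print the
# call template; `penalty` should now show up among the translated arguments.
linear_reg(penalty = 0.1) %>%
  set_engine("keras") %>%
  translate()
```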
R/logistic_reg_data.R (3 additions & 3 deletions)
@@ -288,9 +288,9 @@ set_dependency("logistic_reg", "keras", "magrittr")
set_model_arg(
model = "logistic_reg",
eng = "keras",
parsnip = "decay",
original = "decay",
func = list(pkg = "dials", fun = "weight_decay"),
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)

R/mlp.R (10 additions & 10 deletions)
@@ -265,12 +265,12 @@ class2ind <- function (x, drop2nd = FALSE) {
#' @param x A data frame or matrix of predictors
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
#' @param hidden_units An integer for the number of hidden units.
-#' @param decay A non-negative real number for the amount of weight decay. Either
+#' @param penalty A non-negative real number for the amount of weight decay. Either
#' this parameter _or_ `dropout` can be specified.
#' @param dropout The proportion of parameters to set to zero. Either
-#' this parameter _or_ `decay` can be specified.
+#' this parameter _or_ `penalty` can be specified.
#' @param epochs An integer for the number of passes through the data.
-#' @param act A character string for the type of activation function between layers.
+#' @param activation A character string for the type of activation function between layers.
#' @param seeds A vector of three positive integers to control randomness of the
#' calculations.
#' @param ... Currently ignored.
@@ -279,11 +279,11 @@ class2ind <- function (x, drop2nd = FALSE) {
#' @export
keras_mlp <-
function(x, y,
-hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax",
+hidden_units = 5, penalty = 0, dropout = 0, epochs = 20, activation = "softmax",
seeds = sample.int(10^5, size = 3),
...) {

-if (decay > 0 & dropout > 0) {
+if (penalty > 0 & dropout > 0) {
stop("Please use either dropoput or weight decay.", call. = FALSE)
}
if (!is.matrix(x)) {
@@ -307,20 +307,20 @@ keras_mlp <-

model <- keras::keras_model_sequential()

-if (decay > 0) {
+if (penalty > 0) {
model %>%
keras::layer_dense(
units = hidden_units,
-activation = act,
+activation = activation,
input_shape = ncol(x),
-kernel_regularizer = keras::regularizer_l2(decay),
+kernel_regularizer = keras::regularizer_l2(penalty),
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
)
} else {
model %>%
keras::layer_dense(
units = hidden_units,
-activation = act,
+activation = activation,
input_shape = ncol(x),
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
)
@@ -330,7 +330,7 @@
model %>%
keras::layer_dense(
units = hidden_units,
-activation = act,
+activation = activation,
input_shape = ncol(x),
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
) %>%
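After these renames, the `mlp()` specification arguments (`penalty`, `activation`) line up with the new `keras_mlp()` signature. A minimal sketch of the corresponding specification (assumes `parsnip`; fitting the resulting spec would additionally require the `keras` package):

```r
library(parsnip)

# `penalty` and `dropout` are mutually exclusive in keras_mlp(), so only one
# is set here; translate() shows how the spec arguments map onto keras_mlp().
mlp(mode = "classification", hidden_units = 8, penalty = 0.01,
    epochs = 50, activation = "relu") %>%
  set_engine("keras") %>%
  translate()
```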
R/mlp_data.R (2 additions & 2 deletions)
@@ -24,7 +24,7 @@ set_model_arg(
eng = "keras",
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "weight_decay"),
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)
set_model_arg(
@@ -188,7 +188,7 @@ set_model_arg(
eng = "nnet",
parsnip = "penalty",
original = "decay",
func = list(pkg = "dials", fun = "weight_decay"),
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)
set_model_arg(
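The `func` entries above determine which `dials` parameter object is attached to each argument for tuning. A quick comparison of the two objects involved in this change (assumes the `dials` package is installed; both `penalty()` and `weight_decay()` are exported by `dials`):

```r
library(dials)

# The keras and nnet regularization argument is now associated with
# dials::penalty() (log10 scale) instead of dials::weight_decay().
penalty()
weight_decay()
```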
R/multinom_reg_data.R (3 additions & 3 deletions)
@@ -172,9 +172,9 @@ set_dependency("multinom_reg", "keras", "magrittr")
set_model_arg(
model = "multinom_reg",
eng = "keras",
parsnip = "decay",
original = "decay",
func = list(pkg = "dials", fun = "weight_decay"),
parsnip = "penalty",
original = "penalty",
func = list(pkg = "dials", fun = "penalty"),
has_submodel = FALSE
)

docs/dev/articles/articles/Classification.html (6 additions & 6 deletions)

Some generated files are not rendered by default.