From f8529fc660f2158be8930305defbf30de70e5c66 Mon Sep 17 00:00:00 2001
From: topepo
Date: Sun, 1 Dec 2019 21:28:25 -0500
Subject: [PATCH] changes for #209

---
 NEWS.md                                       |  15 ++
 R/misc.R                                      |   5 +-
 R/multinom_reg.R                              |  13 +-
 R/multinom_reg_data.R                         |  82 +++++++
 .../dev/articles/articles/Classification.html |  46 ++--
 docs/dev/articles/articles/Models.html        |  33 ++-
 docs/dev/articles/articles/Regression.html    |  30 +--
 docs/dev/articles/articles/Scratch.html       |  25 ++-
 docs/dev/articles/articles/Submodels.html     |  14 +-
 docs/dev/articles/parsnip_Intro.html          |   2 +-
 docs/dev/index.html                           |  14 +-
 docs/dev/issue_template.html                  | 200 ++++++++++++++++++
 docs/dev/news/index.html                      |  52 ++++-
 docs/dev/pkgdown.css                          |   9 +-
 docs/dev/pkgdown.yml                          |   2 +-
 docs/dev/reference/C5.0_train.html            |   5 +-
 docs/dev/reference/add_rowindex.html          |   2 +-
 docs/dev/reference/boost_tree.html            |  31 ++-
 docs/dev/reference/check_empty_ellipse.html   |   2 +-
 docs/dev/reference/check_times.html           |  54 ++---
 docs/dev/reference/control_parsnip.html       |   2 +-
 docs/dev/reference/decision_tree.html         |  22 +-
 docs/dev/reference/descriptors.html           |   2 +-
 docs/dev/reference/fit.html                   |  12 +-
 docs/dev/reference/get_model_env.html         |   2 +-
 docs/dev/reference/has_multi_predict.html     |   2 +-
 docs/dev/reference/index.html                 |   4 +-
 docs/dev/reference/keras_mlp.html             |  16 +-
 docs/dev/reference/lending_club.html          |  38 ++--
 docs/dev/reference/linear_reg.html            |  12 +-
 docs/dev/reference/logistic_reg.html          |  12 +-
 docs/dev/reference/make_classes.html          |   2 +-
 docs/dev/reference/mars.html                  |  21 +-
 docs/dev/reference/mlp.html                   |  26 ++-
 docs/dev/reference/model_fit.html             |   2 +-
 docs/dev/reference/model_printer.html         |   2 +-
 docs/dev/reference/model_spec.html            |   8 +-
 docs/dev/reference/multi_predict.html         |  37 ++--
 docs/dev/reference/multinom_reg.html          |  19 +-
 docs/dev/reference/nearest_neighbor.html      |  10 +-
 docs/dev/reference/null_model.html            |   2 +-
 docs/dev/reference/nullmodel.html             |   2 +-
 docs/dev/reference/predict.model_fit.html     |   5 +-
 docs/dev/reference/rand_forest.html           |  16 +-
 docs/dev/reference/reexports.html             |   2 +-
 docs/dev/reference/rpart_train.html           |  13 +-
 docs/dev/reference/set_args.html              |   2 +-
 docs/dev/reference/set_engine.html            |   2 +-
 docs/dev/reference/set_new_model.html         |   2 +-
 docs/dev/reference/show_call.html             |   2 +-
 docs/dev/reference/surv_reg.html              |   5 +-
 docs/dev/reference/svm_poly.html              |  24 ++-
 docs/dev/reference/svm_rbf.html               |  16 +-
 docs/dev/reference/tidy.model_fit.html        |   2 +-
 docs/dev/reference/translate.html             |   2 +-
 docs/dev/reference/type_sum.model_spec.html   |   2 +-
 docs/dev/reference/varying.html               |   2 +-
 .../reference/varying_args.model_spec.html    |   2 +-
 docs/dev/reference/wa_churn.html              |  22 +-
 docs/dev/reference/xgb_train.html             |  17 +-
 man/multinom_reg.Rd                           |   6 +-
 tests/testthat/test_multinom_reg_nnet.R       | 115 ++++++++++
 62 files changed, 873 insertions(+), 277 deletions(-)
 create mode 100644 docs/dev/issue_template.html
 create mode 100644 tests/testthat/test_multinom_reg_nnet.R

diff --git a/NEWS.md b/NEWS.md
index 28d516368..5fa518067 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,3 +1,18 @@
+# parsnip 0.0.5
+
+## Fixes
+
+* A bug ([#206](https://github.com/tidymodels/parsnip/issues/206) and [#234](https://github.com/tidymodels/parsnip/issues/234)) was fixed that caused an error when predicting with a multinomial `glmnet` model.
+
+## Other Changes
+
+* `glmnet` was removed as a dependency since the new version depends on R 3.6.0 or greater. Keeping it would constrain `parsnip` to that same requirement. All `glmnet` tests are run locally.
+
+## New Features
+
+* `nnet` was added as an engine to `multinom_reg()` [#209](https://github.com/tidymodels/parsnip/issues/209)
+
+
 # parsnip 0.0.4
 
 ## New Features
diff --git a/R/misc.R b/R/misc.R
index 754bcebc3..b26ee20b0 100644
--- a/R/misc.R
+++ b/R/misc.R
@@ -276,7 +276,8 @@ update_main_parameters <- function(args, param) {
   }
 
   param <- param[!has_extra_args]
-
-
   args <- utils::modifyList(args, param)
 }
+
+
+
diff --git a/R/multinom_reg.R b/R/multinom_reg.R
index 532489eb9..abbf52cd9 100644
--- a/R/multinom_reg.R
+++ b/R/multinom_reg.R
@@ -34,7 +34,7 @@
 #' The model can be created using the `fit()` function using the
 #' following _engines_:
 #' \itemize{
-#' \item \pkg{R}: `"glmnet"` (the default)
+#' \item \pkg{R}: `"glmnet"` (the default), `"nnet"`
 #' \item \pkg{Stan}: `"stan"`
 #' \item \pkg{keras}: `"keras"`
 #' }
@@ -49,6 +49,10 @@
 #'
 #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")}
 #'
+#' \pkg{nnet}
+#'
+#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "nnet")}
+#'
 #' \pkg{spark}
 #'
 #' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")}
@@ -197,6 +201,13 @@ organize_multnet_prob <- function(x, object) {
   as_tibble(x)
 }
 
+organize_nnet_prob <- function(x, object) {
+  format_classprobs(x)
+}
+
+
+
+
 # ------------------------------------------------------------------------------
 # glmnet call stack for multinomial regression using `predict` when object has
 # classes "_multnet" and "model_fit" (for class predictions):
diff --git a/R/multinom_reg_data.R b/R/multinom_reg_data.R
index 1ae95896c..b56f9c2cf 100644
--- a/R/multinom_reg_data.R
+++ b/R/multinom_reg_data.R
@@ -226,3 +226,85 @@ set_pred(
         x = quote(as.matrix(new_data)))
   )
 )
+
+
+# ------------------------------------------------------------------------------
+
+set_model_engine("multinom_reg", "classification", "nnet")
+set_dependency("multinom_reg", "nnet", "nnet")
+
+set_model_arg(
+  model = "multinom_reg",
+  eng = "nnet",
+  parsnip = "penalty",
+  original = "decay",
+  func = list(pkg = "dials", fun = "penalty"),
+  has_submodel = FALSE
+)
+
+set_fit(
+  model = "multinom_reg",
+  eng = "nnet",
+  mode = "classification",
+  value = list(
+    interface = "formula",
+    protect = c("formula", "data", "weights"),
+    func = c(pkg = "nnet", fun = "multinom"),
+    defaults = list(trace = FALSE)
+  )
+)
+
+
+set_pred(
+  model = "multinom_reg",
+  eng = "nnet",
+  mode = "classification",
+  type = "class",
+  value = list(
+    pre = NULL,
+    post = NULL,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newdata = quote(new_data),
+        type = "class"
+      )
+  )
+)
+
+set_pred(
+  model = "multinom_reg",
+  eng = "nnet",
+  mode = "classification",
+  type = "prob",
+  value = list(
+    pre = NULL,
+    post = organize_nnet_prob,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newdata = quote(new_data),
+        type = "prob"
+      )
+  )
+)
+
+set_pred(
+  model = "multinom_reg",
+  eng = "nnet",
+  mode = "classification",
+  type = "raw",
+  value = list(
+    pre = NULL,
+    post = NULL,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newdata = quote(new_data)
+      )
+  )
+)
+
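For context, here is a minimal usage sketch of the engine these hunks register. It is an illustration rather than part of the patch: it assumes a parsnip build containing this change plus the nnet package, and uses `iris` purely as a stand-in dataset.

```r
library(parsnip)

# Specification: the new set_model_arg() mapping translates parsnip's
# `penalty` into nnet::multinom()'s `decay` argument.
spec <- multinom_reg(penalty = 0.01) %>%
  set_engine("nnet")

# Show the call parsnip will generate -- roughly
# nnet::multinom(formula, data, decay = 0.01, trace = FALSE),
# with trace = FALSE coming from the `defaults` entry in set_fit().
translate(spec)

# Formula fit, matching `interface = "formula"` in set_fit()
nnet_fit <- fit(spec, Species ~ ., data = iris)

# The three prediction types registered via set_pred()
predict(nnet_fit, new_data = iris[1:5, ], type = "class")
predict(nnet_fit, new_data = iris[1:5, ], type = "prob")  # tidied by organize_nnet_prob()
predict(nnet_fit, new_data = iris[1:5, ], type = "raw")
```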

[The remaining hunks regenerate the rendered pkgdown site under docs/dev/. The article pages (Classification, Models, Regression, Scratch, Submodels, and the package intro) were simply re-rendered: only incidental output changed, such as attached package versions, console line widths, fit timings, and the resampled accuracy/confusion-matrix numbers in the classification example.]
[docs/dev/issue_template.html is a new page: the pkgdown rendering of the GitHub issue template ("Bug report or feature request"). It asks contributors to include a minimal reprex (via reprex::reprex(si = TRUE) after installing the reprex and sessioninfo packages), avoid confidential data (simulating with caret::twoClassSim() or caret::SLC14_1() is suggested), run code sequentially, use set.seed() so any randomness is reproducible, and check Stack Overflow or RStudio Community for existing answers before filing.]

[docs/dev/news/index.html, the rendered changelog, was also regenerated; its new parsnip 0.0.5 section mirrors the NEWS.md entries above, and the rest of the page lists the earlier releases (0.0.4 and before).]
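The diffstat above also records a new 115-line test file, tests/testthat/test_multinom_reg_nnet.R, whose contents are not included in this excerpt. Purely as a hypothetical sketch of the kind of check such a file might contain (not the actual test code), again assuming the patched parsnip and the nnet package are available:

```r
library(testthat)
library(parsnip)

test_that("multinom_reg() with the nnet engine fits and predicts", {
  skip_if_not_installed("nnet")

  nnet_fit <- multinom_reg(penalty = 0.01) %>%
    set_engine("nnet") %>%
    fit(Species ~ ., data = iris)

  cls <- predict(nnet_fit, new_data = iris[1:3, ], type = "class")
  prb <- predict(nnet_fit, new_data = iris[1:3, ], type = "prob")

  # parsnip's standardized prediction output: a tibble with a .pred_class
  # column for class predictions and one .pred_* column per outcome level
  expect_s3_class(cls, "tbl_df")
  expect_true(".pred_class" %in% names(cls))
  expect_equal(ncol(prb), nlevels(iris$Species))
})
```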