Skip to content

Commit 63f5b41

Browse files
authored
Merge pull request #242 from tidymodels/multinom-nnet
changes for #209
2 parents a42d307 + f8529fc commit 63f5b41

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+873
-277
lines changed

NEWS.md

+15
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# parsnip 0.0.5
2+
3+
## Fixes
4+
5+
* A bug ([#206](https://github.com/tidymodels/parsnip/issues/206) and [#234](https://github.com/tidymodels/parsnip/issues/234)) was fixed that caused an error when predicting with a multinomial `glmnet` model.
6+
7+
## Other Changes
8+
9+
* `glmnet` was removed as a dependency since the new version depends on 3.6.0 or greater. Keeping it would constrain `parsnip` to that same requirement. All `glmnet` tests are run locally.
10+
11+
## New Features
12+
13+
* `nnet` was added as an engine to `multinom_reg()` [#209](https://github.com/tidymodels/parsnip/issues/209)
14+
15+
116
# parsnip 0.0.4
217

318
## New Features

R/misc.R

+3-2
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ update_main_parameters <- function(args, param) {
276276
}
277277
param <- param[!has_extra_args]
278278

279-
280-
281279
args <- utils::modifyList(args, param)
282280
}
281+
282+
283+

R/multinom_reg.R

+12-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
#' The model can be created using the `fit()` function using the
3535
#' following _engines_:
3636
#' \itemize{
37-
#' \item \pkg{R}: `"glmnet"` (the default)
37+
#' \item \pkg{R}: `"glmnet"` (the default), `"nnet"`
3838
#' \item \pkg{Stan}: `"stan"`
3939
#' \item \pkg{keras}: `"keras"`
4040
#' }
@@ -49,6 +49,10 @@
4949
#'
5050
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "glmnet")}
5151
#'
52+
#' \pkg{nnet}
53+
#'
54+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "nnet")}
55+
#'
5256
#' \pkg{spark}
5357
#'
5458
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")}
@@ -197,6 +201,13 @@ organize_multnet_prob <- function(x, object) {
197201
as_tibble(x)
198202
}
199203

204+
organize_nnet_prob <- function(x, object) {
205+
format_classprobs(x)
206+
}
207+
208+
209+
210+
200211
# ------------------------------------------------------------------------------
201212
# glmnet call stack for multinomial regression using `predict` when object has
202213
# classes "_multnet" and "model_fit" (for class predictions):

R/multinom_reg_data.R

+82
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,85 @@ set_pred(
226226
x = quote(as.matrix(new_data)))
227227
)
228228
)
229+
230+
231+
# ------------------------------------------------------------------------------
232+
233+
set_model_engine("multinom_reg", "classification", "nnet")
234+
set_dependency("multinom_reg", "nnet", "nnet")
235+
236+
set_model_arg(
237+
model = "multinom_reg",
238+
eng = "nnet",
239+
parsnip = "penalty",
240+
original = "decay",
241+
func = list(pkg = "dials", fun = "penalty"),
242+
has_submodel = FALSE
243+
)
244+
245+
set_fit(
246+
model = "multinom_reg",
247+
eng = "nnet",
248+
mode = "classification",
249+
value = list(
250+
interface = "formula",
251+
protect = c("formula", "data", "weights"),
252+
func = c(pkg = "nnet", fun = "multinom"),
253+
defaults = list(trace = FALSE)
254+
)
255+
)
256+
257+
258+
set_pred(
259+
model = "multinom_reg",
260+
eng = "nnet",
261+
mode = "classification",
262+
type = "class",
263+
value = list(
264+
pre = NULL,
265+
post = NULL,
266+
func = c(fun = "predict"),
267+
args =
268+
list(
269+
object = quote(object$fit),
270+
newdata = quote(new_data),
271+
type = "class"
272+
)
273+
)
274+
)
275+
276+
set_pred(
277+
model = "multinom_reg",
278+
eng = "nnet",
279+
mode = "classification",
280+
type = "prob",
281+
value = list(
282+
pre = NULL,
283+
post = organize_nnet_prob,
284+
func = c(fun = "predict"),
285+
args =
286+
list(
287+
object = quote(object$fit),
288+
newdata = quote(new_data),
289+
type = "prob"
290+
)
291+
)
292+
)
293+
294+
set_pred(
295+
model = "multinom_reg",
296+
eng = "nnet",
297+
mode = "classification",
298+
type = "raw",
299+
value = list(
300+
pre = NULL,
301+
post = NULL,
302+
func = c(fun = "predict"),
303+
args =
304+
list(
305+
object = quote(object$fit),
306+
newdata = quote(new_data)
307+
)
308+
)
309+
)
310+

docs/dev/articles/articles/Classification.html

+23-23
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/dev/articles/articles/Models.html

+31-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)