diff --git a/DESCRIPTION b/DESCRIPTION index d3b570139..14121dc61 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: parsnip Title: A Common API to Modeling and Analysis Functions -Version: 1.0.4.9002 +Version: 1.0.4.9003 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), person("Davis", "Vaughan", , "davis@posit.co", role = "aut"), diff --git a/NEWS.md b/NEWS.md index 436703c8e..ecece24a7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -16,6 +16,8 @@ * Several internal functions (to help work with `Surv` objects) were added as a standalone file that can be used in other packages via `usethis::use_standalone("tidymodels/parsnip")`. +* `multi_predict()` methods for `linear_reg()`, `logistic_reg()`, and `multinomial_reg()` models fitted with the `"glmnet"` engine now check the `type` better and error accordingly (#900). + * Rather than being implemented in each method, the check for the `new_data` argument being mistakenly passed as `newdata` to `multi_predict()` now happens in the generic. Packages re-exporting the `multi_predict()` generic and implementing now-duplicate checks may see new failures and can remove their own analogous checks. This check already existed in all `predict()` methods (via `predict.model_fit()`) and all parsnip `multi_predict()` methods (#525). * `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545). diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R index 451848784..1384152cd 100644 --- a/R/glmnet-engines.R +++ b/R/glmnet-engines.R @@ -173,14 +173,20 @@ multi_predict_glmnet <- function(object, type = NULL, penalty = NULL, ...) { + type <- check_pred_type(object, type) + check_spec_pred_type(object, type) + if (type == "prob") { + check_spec_levels(object) + } + + dots <- list(...) + if (object$spec$mode == "classification") { if (is_quosure(penalty)) { penalty <- eval_tidy(penalty) } } - dots <- list(...) - object$spec <- eval_args(object$spec) if (is.null(penalty)) { @@ -195,12 +201,6 @@ multi_predict_glmnet <- function(object, model_type <- class(object$spec)[1] if (object$spec$mode == "classification") { - if (is.null(type)) { - type <- "class" - } - if (!(type %in% c("class", "prob", "link", "raw"))) { - rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.") - } if (type == "prob" | model_type == "logistic_reg") { dots$type <- "response"