diff --git a/R/glmnet-engines.R b/R/glmnet-engines.R new file mode 100644 index 000000000..b22dcf802 --- /dev/null +++ b/R/glmnet-engines.R @@ -0,0 +1,425 @@ +# glmnet call stack using `predict()` when object has +# classes "_" and "model_fit": +# +# predict() +# predict._(penalty = NULL) +# predict_glmnet(penalty = NULL) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_numeric() +# predict_numeric._() +# predict_numeric_glmnet() +# predict_numeric.model_fit() +# predict.() + + +# glmnet call stack using `multi_predict` when object has +# classes "_" and "model_fit": +# +# multi_predict() +# multi_predict._(penalty = NULL) +# predict._(multi = TRUE) +# predict_glmnet(multi = TRUE) <-- checks and sets penalty +# predict.model_fit() <-- checks for extra vars in ... +# predict_raw() +# predict_raw._() +# predict_raw_glmnet() +# predict_raw.model_fit(opts = list(s = penalty)) +# predict.() + + +predict_glmnet <- function(object, + new_data, + type = NULL, + opts = list(), + penalty = NULL, + multi = FALSE, + ...) { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (is.null(penalty) & !is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } + + object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi) + + object$spec <- eval_args(object$spec) + predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) +} + +predict_numeric_glmnet <- function(object, new_data, ...) { + object$spec <- eval_args(object$spec) + predict_numeric.model_fit(object, new_data = new_data, ...) +} + +predict_class_glmnet <- function(object, new_data, ...) { + object$spec <- eval_args(object$spec) + predict_class.model_fit(object, new_data = new_data, ...) +} + +predict_classprob_glmnet <- function(object, new_data, ...) { + object$spec <- eval_args(object$spec) + predict_classprob.model_fit(object, new_data = new_data, ...) +} + +predict_raw_glmnet <- function(object, new_data, opts = list(), ...) { + object$spec <- eval_args(object$spec) + + opts$s <- object$spec$args$penalty + + predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) +} + +# translation of glmnet classes to parsnip models +# elnet ~ linear_reg +# lognet ~ logistic_reg +# multnet ~ multinom_reg +# glmnetfit: that's a catch-all class for glmnet models fitted with a base-R +# family, thus can be any of linear_reg, logistic_reg, multinom_reg, poisson_reg + +#' @export +predict._elnet <- predict_glmnet + +#' @export +predict_numeric._elnet <- predict_numeric_glmnet + +#' @export +predict_raw._elnet <- predict_raw_glmnet + +#' @export +predict._lognet <- predict_glmnet + +#' @export +predict_class._lognet <- predict_class_glmnet + +#' @export +predict_classprob._lognet <- predict_classprob_glmnet + +#' @export +predict_raw._lognet <- predict_raw_glmnet + +#' @export +predict._multnet <- predict_glmnet + +#' @export +predict_class._multnet <- predict_class_glmnet + +#' @export +predict_classprob._multnet <- predict_classprob_glmnet + +#' @export +predict_raw._multnet <- predict_raw_glmnet + +#' @export +predict._glmnetfit <- predict_glmnet + +#' @export +predict_numeric._glmnetfit <- predict_numeric_glmnet + +#' @export +predict_class._glmnetfit <- predict_class_glmnet + +#' @export +predict_classprob._glmnetfit <- predict_classprob_glmnet + +#' @export +predict_raw._glmnetfit <- predict_raw_glmnet + +#' Organize glmnet predictions +#' +#' This function is for developer use and organizes predictions from glmnet +#' models. +#' +#' @param x Predictions as returned by the `predict()` method for glmnet models. +#' @param object An object of class `model_fit`. +#' +#' @rdname glmnet_helpers_prediction +#' @keywords internal +#' @export +.organize_glmnet_pred <- function(x, object) { + unname(x[, 1]) +} + +organize_glmnet_class <- function(x, object) { + prob_to_class_2(x[, 1], object) +} + +organize_glmnet_prob <- function(x, object) { + res <- tibble(v1 = 1 - x[, 1], v2 = x[, 1]) + colnames(res) <- object$lvl + res +} + +organize_multnet_class <- function(x, object) { + if (vec_size(x) > 1) { + x <- x[,1] + } else { + x <- as.character(x) + } + x +} + +organize_multnet_prob <- function(x, object) { + if (vec_size(x) > 1) { + x <- as_tibble(x[,,1]) + } else { + x <- tibble::as_tibble_row(x[,,1]) + } + x +} + +# ------------------------------------------------------------------------- + +multi_predict_glmnet <- function(object, + new_data, + type = NULL, + penalty = NULL, + ...) { + + if (any(names(enquos(...)) == "newdata")) { + rlang::abort("Did you mean to use `new_data` instead of `newdata`?") + } + + if (object$spec$mode == "classification") { + if (is_quosure(penalty)) { + penalty <- eval_tidy(penalty) + } + } + + dots <- list(...) + + object$spec <- eval_args(object$spec) + + if (is.null(penalty)) { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + if (!is.null(object$spec$args$penalty)) { + penalty <- object$spec$args$penalty + } else { + penalty <- object$fit$lambda + } + } + + model_type <- class(object$spec)[1] + + if (object$spec$mode == "classification") { + if (is.null(type)) { + type <- "class" + } + if (!(type %in% c("class", "prob", "link", "raw"))) { + rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.") + } + if (type == "prob" | + model_type == "logistic_reg") { + dots$type <- "response" + } else { + dots$type <- type + } + } + + pred <- predict(object, new_data = new_data, type = "raw", + opts = dots, penalty = penalty, multi = TRUE) + + + res <- switch( + model_type, + "linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty), + "logistic_reg" = format_glmnet_multi_logistic_reg(pred, + penalty = penalty, + type = type, + lvl = object$lvl), + "multinom_reg" = format_glmnet_multi_multinom_reg(pred, + penalty = penalty, + type = type, + n_rows = nrow(new_data), + lvl = object$lvl) + ) + + res +} + +#' @export +#' @rdname multi_predict +#' @param penalty A numeric vector of penalty values. +multi_predict._elnet <- multi_predict_glmnet + +#' @export +#' @rdname multi_predict +multi_predict._lognet <- multi_predict_glmnet + +#' @export +#' @rdname multi_predict +multi_predict._multnet <- multi_predict_glmnet + +#' @export +multi_predict._glmnetfit <- multi_predict_glmnet + +format_glmnet_multi_linear_reg <- function(pred, penalty) { + param_key <- tibble(group = colnames(pred), penalty = penalty) + pred <- as_tibble(pred) + pred$.row <- 1:nrow(pred) + pred <- gather(pred, group, .pred, -.row) + if (utils::packageVersion("dplyr") >= "1.0.99.9000") { + pred <- full_join(param_key, pred, by = "group", multiple = "all") + } else { + pred <- full_join(param_key, pred, by = "group") + } + pred$group <- NULL + pred <- arrange(pred, .row, penalty) + .row <- pred$.row + pred$.row <- NULL + pred <- split(pred, .row) + names(pred) <- NULL + tibble(.pred = pred) +} + +format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) { + + type <- rlang::arg_match(type, c("class", "prob")) + + penalty_key <- tibble(s = colnames(pred), penalty = penalty) + + pred <- as_tibble(pred) + pred$.row <- seq_len(nrow(pred)) + pred <- tidyr::pivot_longer(pred, -.row, names_to = "s", values_to = ".pred") + + if (type == "class") { + pred <- pred %>% + dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]), + .pred_class = factor(.pred_class, levels = lvl), + .keep = "unused") + } else { + pred <- pred %>% + dplyr::mutate(.pred_class_2 = 1 - .pred) %>% + rlang::set_names(c(".row", "s", paste0(".pred_", rev(lvl)))) %>% + dplyr::select(c(".row", "s", paste0(".pred_", lvl))) + } + + if (utils::packageVersion("dplyr") >= "1.0.99.9000") { + pred <- dplyr::full_join(penalty_key, pred, by = "s", multiple = "all") + } else { + pred <- dplyr::full_join(penalty_key, pred, by = "s") + } + + pred <- pred %>% + dplyr::select(-s) %>% + dplyr::arrange(penalty) %>% + tidyr::nest(.by = .row, .key = ".pred") %>% + dplyr::select(-.row) + + pred +} + +format_glmnet_multi_multinom_reg <- function(pred, penalty, type, n_rows, lvl) { + format_probs <- function(x) { + x <- as_tibble(x) + names(x) <- paste0(".pred_", names(x)) + nms <- names(x) + x$.row <- 1:nrow(x) + x[, c(".row", nms)] + } + + if (type == "prob") { + pred <- apply(pred, 3, format_probs) + names(pred) <- NULL + pred <- map_dfr(pred, function(x) x) + pred$penalty <- rep(penalty, each = n_rows) + pred <- dplyr::relocate(pred, penalty) + } else { + pred <- + tibble( + .row = rep(1:n_rows, length(penalty)), + penalty = rep(penalty, each = n_rows), + .pred_class = factor(as.vector(pred), levels = lvl) + ) + } + + pred <- arrange(pred, .row, penalty) + .row <- pred$.row + pred$.row <- NULL + pred <- split(pred, .row) + names(pred) <- NULL + tibble(.pred = pred) +} + +# ------------------------------------------------------------------------- + +#' Helper functions for checking the penalty of glmnet models +#' +#' @description +#' These functions are for developer use. +#' +#' `.check_glmnet_penalty_fit()` checks that the model specification for fitting a +#' glmnet model contains a single value. +#' +#' `.check_glmnet_penalty_predict()` checks that the penalty value used for prediction is valid. +#' If called by `predict()`, it needs to be a single value. Multiple values are +#' allowed for `multi_predict()`. +#' +#' @param x An object of class `model_spec`. +#' @rdname glmnet_helpers +#' @keywords internal +#' @export +.check_glmnet_penalty_fit <- function(x) { + pen <- rlang::eval_tidy(x$args$penalty) + + if (length(pen) != 1) { + rlang::abort(c( + "For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).", + glue::glue("There are {length(pen)} values for `penalty`."), + "To try multiple values for total regularization, use the tune package.", + "To predict multiple penalties, use `multi_predict()`" + )) + } +} + +#' @param penalty A penalty value to check. +#' @param object An object of class `model_fit`. +#' @param multi A logical indicating if multiple values are allowed. +#' +#' @rdname glmnet_helpers +#' @keywords internal +#' @export +.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) { + if (is.null(penalty)) { + penalty <- object$fit$lambda + } + + # when using `predict()`, allow for a single lambda + if (!multi) { + if (length(penalty) != 1) { + rlang::abort( + glue::glue( + "`penalty` should be a single numeric value. `multi_predict()` ", + "can be used to get multiple predictions per row of data.", + ) + ) + } + } + + if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) { + rlang::abort( + glue::glue( + "The glmnet model was fit with a single penalty value of ", + "{object$fit$lambda}. Predicting with a value of {penalty} ", + "will give incorrect results from `glmnet()`." + ) + ) + } + + penalty +} + +set_glmnet_penalty_path <- function(x) { + if (any(names(x$eng_args) == "path_values")) { + # Since we decouple the parsnip `penalty` argument from being the same + # as the glmnet `lambda` value, `path_values` allows users to set the + # path differently from the default that glmnet uses. See + # https://github.com/tidymodels/parsnip/issues/431 + x$method$fit$args$lambda <- x$eng_args$path_values + x$eng_args$path_values <- NULL + x$method$fit$args$path_values <- NULL + } else { + # See discussion in https://github.com/tidymodels/parsnip/issues/195 + x$method$fit$args$lambda <- NULL + } + x +} + diff --git a/R/glmnet.R b/R/glmnet.R deleted file mode 100644 index 7182ffdce..000000000 --- a/R/glmnet.R +++ /dev/null @@ -1,172 +0,0 @@ -# glmnet call stack using `predict()` when object has -# classes "_" and "model_fit": -# -# predict() -# predict._(penalty = NULL) -# predict_glmnet(penalty = NULL) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_numeric() -# predict_numeric._() -# predict_numeric_glmnet() -# predict_numeric.model_fit() -# predict.() - - -# glmnet call stack using `multi_predict` when object has -# classes "_" and "model_fit": -# -# multi_predict() -# multi_predict._(penalty = NULL) -# predict._(multi = TRUE) -# predict_glmnet(multi = TRUE) <-- checks and sets penalty -# predict.model_fit() <-- checks for extra vars in ... -# predict_raw() -# predict_raw._() -# predict_raw_glmnet() -# predict_raw.model_fit(opts = list(s = penalty)) -# predict.() - - -predict_glmnet <- function(object, - new_data, - type = NULL, - opts = list(), - penalty = NULL, - multi = FALSE, - ...) { - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (is.null(penalty) & !is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } - - object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi) - - object$spec <- eval_args(object$spec) - predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...) -} - -predict_numeric_glmnet <- function(object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_numeric.model_fit(object, new_data = new_data, ...) -} - -predict_class_glmnet <- function(object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_class.model_fit(object, new_data = new_data, ...) -} - -predict_classprob_glmnet <- function(object, new_data, ...) { - object$spec <- eval_args(object$spec) - predict_classprob.model_fit(object, new_data = new_data, ...) -} - -predict_raw_glmnet <- function(object, new_data, opts = list(), ...) { - object$spec <- eval_args(object$spec) - - opts$s <- object$spec$args$penalty - - predict_raw.model_fit(object, new_data = new_data, opts = opts, ...) -} - -multi_predict_glmnet <- function(object, - new_data, - type = NULL, - penalty = NULL, - ...) { - - if (any(names(enquos(...)) == "newdata")) { - rlang::abort("Did you mean to use `new_data` instead of `newdata`?") - } - - if (object$spec$mode == "classification") { - if (is_quosure(penalty)) { - penalty <- eval_tidy(penalty) - } - } - - dots <- list(...) - - object$spec <- eval_args(object$spec) - - if (is.null(penalty)) { - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - if (!is.null(object$spec$args$penalty)) { - penalty <- object$spec$args$penalty - } else { - penalty <- object$fit$lambda - } - } - - model_type <- class(object$spec)[1] - - if (object$spec$mode == "classification") { - if (is.null(type)) { - type <- "class" - } - if (!(type %in% c("class", "prob", "link", "raw"))) { - rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.") - } - if (type == "prob" | - model_type == "logistic_reg") { - dots$type <- "response" - } else { - dots$type <- type - } - } - - pred <- predict(object, new_data = new_data, type = "raw", - opts = dots, penalty = penalty, multi = TRUE) - - - res <- switch( - model_type, - "linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty), - "logistic_reg" = format_glmnet_multi_logistic_reg(pred, - penalty = penalty, - type = type, - lvl = object$lvl), - "multinom_reg" = format_glmnet_multi_multinom_reg(pred, - penalty = penalty, - type = type, - n_rows = nrow(new_data), - lvl = object$lvl) - ) - - res -} - -#' @export -predict._glmnetfit <- predict_glmnet - -#' @export -predict_numeric._glmnetfit <- predict_numeric_glmnet - -#' @export -predict_class._glmnetfit <- predict_class_glmnet - -#' @export -predict_classprob._glmnetfit <- predict_classprob_glmnet - -#' @export -predict_raw._glmnetfit <- predict_raw_glmnet - -#' @export -multi_predict._glmnetfit <- multi_predict_glmnet - -# ------------------------------------------------------------------------- - -set_glmnet_penalty_path <- function(x) { - if (any(names(x$eng_args) == "path_values")) { - # Since we decouple the parsnip `penalty` argument from being the same - # as the glmnet `lambda` value, `path_values` allows users to set the - # path differently from the default that glmnet uses. See - # https://github.com/tidymodels/parsnip/issues/431 - x$method$fit$args$lambda <- x$eng_args$path_values - x$eng_args$path_values <- NULL - x$method$fit$args$path_values <- NULL - } else { - # See discussion in https://github.com/tidymodels/parsnip/issues/195 - x$method$fit$args$lambda <- NULL - } - x -} diff --git a/R/linear_reg.R b/R/linear_reg.R index 872c22f8e..f09a268ac 100644 --- a/R/linear_reg.R +++ b/R/linear_reg.R @@ -117,53 +117,3 @@ check_args.linear_reg <- function(object) { invisible(object) } - -# ------------------------------------------------------------------------------ - -#' Organize glmnet predictions -#' -#' This function is for developer use and organizes predictions from glmnet -#' models. -#' -#' @param x Predictions as returned by the `predict()` method for glmnet models. -#' @param object An object of class `model_fit`. -#' -#' @rdname glmnet_helpers_prediction -#' @keywords internal -#' @export -.organize_glmnet_pred <- function(x, object) { - unname(x[, 1]) -} - -#' @export -predict._elnet <- predict_glmnet - -#' @export -predict_numeric._elnet <- predict_numeric_glmnet - -#' @export -predict_raw._elnet <- predict_raw_glmnet - -#' @export -#'@rdname multi_predict -#' @param penalty A numeric vector of penalty values. -multi_predict._elnet <- multi_predict_glmnet - -format_glmnet_multi_linear_reg <- function(pred, penalty) { - param_key <- tibble(group = colnames(pred), penalty = penalty) - pred <- as_tibble(pred) - pred$.row <- 1:nrow(pred) - pred <- gather(pred, group, .pred, -.row) - if (utils::packageVersion("dplyr") >= "1.0.99.9000") { - pred <- full_join(param_key, pred, by = "group", multiple = "all") - } else { - pred <- full_join(param_key, pred, by = "group") - } - pred$group <- NULL - pred <- arrange(pred, .row, penalty) - .row <- pred$.row - pred$.row <- NULL - pred <- split(pred, .row) - names(pred) <- NULL - tibble(.pred = pred) -} diff --git a/R/logistic_reg.R b/R/logistic_reg.R index ce861ed3e..86266494c 100644 --- a/R/logistic_reg.R +++ b/R/logistic_reg.R @@ -164,71 +164,6 @@ prob_to_class_2 <- function(x, object) { unname(x) } -organize_glmnet_class <- function(x, object) { - prob_to_class_2(x[, 1], object) -} - -organize_glmnet_prob <- function(x, object) { - res <- tibble(v1 = 1 - x[, 1], v2 = x[, 1]) - colnames(res) <- object$lvl - res -} - -# ------------------------------------------------------------------------------ - -#' @export -predict._lognet <- predict_glmnet - -#' @export -#' @rdname multi_predict -multi_predict._lognet <- multi_predict_glmnet - -format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) { - - type <- rlang::arg_match(type, c("class", "prob")) - - penalty_key <- tibble(s = colnames(pred), penalty = penalty) - - pred <- as_tibble(pred) - pred$.row <- seq_len(nrow(pred)) - pred <- tidyr::pivot_longer(pred, -.row, names_to = "s", values_to = ".pred") - - if (type == "class") { - pred <- pred %>% - dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]), - .pred_class = factor(.pred_class, levels = lvl), - .keep = "unused") - } else { - pred <- pred %>% - dplyr::mutate(.pred_class_2 = 1 - .pred) %>% - rlang::set_names(c(".row", "s", paste0(".pred_", rev(lvl)))) %>% - dplyr::select(c(".row", "s", paste0(".pred_", lvl))) - } - - if (utils::packageVersion("dplyr") >= "1.0.99.9000") { - pred <- dplyr::full_join(penalty_key, pred, by = "s", multiple = "all") - } else { - pred <- dplyr::full_join(penalty_key, pred, by = "s") - } - - pred <- pred %>% - dplyr::select(-s) %>% - dplyr::arrange(penalty) %>% - tidyr::nest(.by = .row, .key = ".pred") %>% - dplyr::select(-.row) - - pred -} - -#' @export -predict_class._lognet <- predict_class_glmnet - -#' @export -predict_classprob._lognet <- predict_classprob_glmnet - -#' @export -predict_raw._lognet <- predict_raw_glmnet - # ------------------------------------------------------------------------------ liblinear_preds <- function(results, object) { diff --git a/R/misc.R b/R/misc.R index 0962f2386..b3fb42e35 100644 --- a/R/misc.R +++ b/R/misc.R @@ -459,74 +459,6 @@ stan_conf_int <- function(object, newdata) { # ------------------------------------------------------------------------------ - -#' Helper functions for checking the penalty of glmnet models -#' -#' @description -#' These functions are for developer use. -#' -#' `.check_glmnet_penalty_fit()` checks that the model specification for fitting a -#' glmnet model contains a single value. -#' -#' `.check_glmnet_penalty_predict()` checks that the penalty value used for prediction is valid. -#' If called by `predict()`, it needs to be a single value. Multiple values are -#' allowed for `multi_predict()`. -#' -#' @param x An object of class `model_spec`. -#' @rdname glmnet_helpers -#' @keywords internal -#' @export -.check_glmnet_penalty_fit <- function(x) { - pen <- rlang::eval_tidy(x$args$penalty) - - if (length(pen) != 1) { - rlang::abort(c( - "For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).", - glue::glue("There are {length(pen)} values for `penalty`."), - "To try multiple values for total regularization, use the tune package.", - "To predict multiple penalties, use `multi_predict()`" - )) - } -} - -#' @param penalty A penalty value to check. -#' @param object An object of class `model_fit`. -#' @param multi A logical indicating if multiple values are allowed. -#' -#' @rdname glmnet_helpers -#' @keywords internal -#' @export -.check_glmnet_penalty_predict <- function(penalty = NULL, object, multi = FALSE) { - if (is.null(penalty)) { - penalty <- object$fit$lambda - } - - # when using `predict()`, allow for a single lambda - if (!multi) { - if (length(penalty) != 1) { - rlang::abort( - glue::glue( - "`penalty` should be a single numeric value. `multi_predict()` ", - "can be used to get multiple predictions per row of data.", - ) - ) - } - } - - if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda) { - rlang::abort( - glue::glue( - "The glmnet model was fit with a single penalty value of ", - "{object$fit$lambda}. Predicting with a value of {penalty} ", - "will give incorrect results from `glmnet()`." - ) - ) - } - - penalty -} - - check_case_weights <- function(x, spec) { if (is.null(x) | spec$engine == "spark") { return(invisible(NULL)) diff --git a/R/multinom_reg.R b/R/multinom_reg.R index 465ebf205..776b6ec6c 100644 --- a/R/multinom_reg.R +++ b/R/multinom_reg.R @@ -112,24 +112,6 @@ check_args.multinom_reg <- function(object) { # ------------------------------------------------------------------------------ -organize_multnet_class <- function(x, object) { - if (vec_size(x) > 1) { - x <- x[,1] - } else { - x <- as.character(x) - } - x -} - -organize_multnet_prob <- function(x, object) { - if (vec_size(x) > 1) { - x <- as_tibble(x[,,1]) - } else { - x <- tibble::as_tibble_row(x[,,1]) - } - x -} - organize_nnet_prob <- function(x, object) { if (is.null(nrow(x))) { x_names <- names(x) @@ -138,56 +120,3 @@ organize_nnet_prob <- function(x, object) { } format_classprobs(x) } - - - - -# ------------------------------------------------------------------------------ - -#' @export -predict._multnet <- predict_glmnet - -#' @export -#' @rdname multi_predict -multi_predict._multnet <- multi_predict_glmnet - -#' @export -predict_class._multnet <- predict_class_glmnet - -#' @export -predict_classprob._multnet <- predict_classprob_glmnet - -#' @export -predict_raw._multnet <- predict_raw_glmnet - -format_glmnet_multi_multinom_reg <- function(pred, penalty, type, n_rows, lvl) { - format_probs <- function(x) { - x <- as_tibble(x) - names(x) <- paste0(".pred_", names(x)) - nms <- names(x) - x$.row <- 1:nrow(x) - x[, c(".row", nms)] - } - - if (type == "prob") { - pred <- apply(pred, 3, format_probs) - names(pred) <- NULL - pred <- map_dfr(pred, function(x) x) - pred$penalty <- rep(penalty, each = n_rows) - pred <- dplyr::relocate(pred, penalty) - } else { - pred <- - tibble( - .row = rep(1:n_rows, length(penalty)), - penalty = rep(penalty, each = n_rows), - .pred_class = factor(as.vector(pred), levels = lvl) - ) - } - - pred <- arrange(pred, .row, penalty) - .row <- pred$.row - pred$.row <- NULL - pred <- split(pred, .row) - names(pred) <- NULL - tibble(.pred = pred) -} diff --git a/man/glmnet_helpers.Rd b/man/glmnet_helpers.Rd index 19451a7de..e6fcfc3aa 100644 --- a/man/glmnet_helpers.Rd +++ b/man/glmnet_helpers.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/misc.R +% Please edit documentation in R/glmnet-engines.R \name{.check_glmnet_penalty_fit} \alias{.check_glmnet_penalty_fit} \alias{.check_glmnet_penalty_predict} diff --git a/man/glmnet_helpers_prediction.Rd b/man/glmnet_helpers_prediction.Rd index 444784b9e..71d8f5c1c 100644 --- a/man/glmnet_helpers_prediction.Rd +++ b/man/glmnet_helpers_prediction.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/linear_reg.R +% Please edit documentation in R/glmnet-engines.R \name{.organize_glmnet_pred} \alias{.organize_glmnet_pred} \title{Organize glmnet predictions} diff --git a/man/multi_predict.Rd b/man/multi_predict.Rd index 8d7ecda95..e0b93106e 100644 --- a/man/multi_predict.Rd +++ b/man/multi_predict.Rd @@ -1,7 +1,6 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/aaa_multi_predict.R, R/boost_tree.R, -% R/linear_reg.R, R/logistic_reg.R, R/mars.R, R/mlp.R, R/multinom_reg.R, -% R/nearest_neighbor.R +% R/glmnet-engines.R, R/mars.R, R/mlp.R, R/nearest_neighbor.R \name{multi_predict} \alias{multi_predict} \alias{multi_predict.default} @@ -9,9 +8,9 @@ \alias{multi_predict._C5.0} \alias{multi_predict._elnet} \alias{multi_predict._lognet} +\alias{multi_predict._multnet} \alias{multi_predict._earth} \alias{multi_predict._torch_mlp} -\alias{multi_predict._multnet} \alias{multi_predict._train.kknn} \title{Model predictions across many sub-models} \usage{ @@ -27,12 +26,12 @@ multi_predict(object, ...) \method{multi_predict}{`_lognet`}(object, new_data, type = NULL, penalty = NULL, ...) +\method{multi_predict}{`_multnet`}(object, new_data, type = NULL, penalty = NULL, ...) + \method{multi_predict}{`_earth`}(object, new_data, type = NULL, num_terms = NULL, ...) \method{multi_predict}{`_torch_mlp`}(object, new_data, type = NULL, epochs = NULL, ...) -\method{multi_predict}{`_multnet`}(object, new_data, type = NULL, penalty = NULL, ...) - \method{multi_predict}{`_train.kknn`}(object, new_data, type = NULL, neighbors = NULL, ...) } \arguments{