diff --git a/R/tunable.R b/R/tunable.R index 244bbba30..675a22c79 100644 --- a/R/tunable.R +++ b/R/tunable.R @@ -4,7 +4,8 @@ #' @export tunable.model_spec <- function(x, ...) { - mod_env <- rlang::ns_env("parsnip")$parsnip + + mod_env <- get_model_env() if (is.null(x$engine)) { stop("Please declare an engine first using `set_engine()`.", call. = FALSE) @@ -17,27 +18,35 @@ tunable.model_spec <- function(x, ...) { sep = "", call. = FALSE) } - arg_vals <- - mod_env[[arg_name]] %>% - dplyr::filter(engine == x$engine) %>% - dplyr::select(name = parsnip, call_info = func) %>% - dplyr::full_join( - tibble::tibble(name = c(names(x$args), names(x$eng_args))), - by = "name" - ) %>% - dplyr::mutate( - source = "model_spec", - component = mod_type(x), - component_id = dplyr::if_else(name %in% names(x$args), "main", "engine") + arg_vals <- mod_env[[arg_name]] + arg_vals <- arg_vals[arg_vals$engine == x$engine, c("parsnip", "func")] + names(arg_vals)[names(arg_vals) == "parsnip"] <- "name" + names(arg_vals)[names(arg_vals) == "func"] <- "call_info" + + extra_args <- c(names(x$args), names(x$eng_args)) + extra_args <- extra_args[!extra_args %in% arg_vals$name] + + extra_args_tbl <- + tibble::new_tibble( + list(name = extra_args, call_info = vector("list", vctrs::vec_size(extra_args))), + nrow = vctrs::vec_size(extra_args) ) - if (nrow(arg_vals) > 0) { - has_info <- purrr::map_lgl(arg_vals$call_info, is.null) - rm_list <- !(has_info & (arg_vals$component_id == "main")) + res <- vctrs::vec_rbind(arg_vals, extra_args_tbl) - arg_vals <- arg_vals[rm_list,] + res$source <- "model_spec" + res$component <- mod_type(x) + res$component_id <- "main" + res$component_id[!res$name %in% names(x$args)] <- "engine" + + if (nrow(res) > 0) { + has_info <- purrr::map_lgl(res$call_info, is.null) + rm_list <- !(has_info & (res$component_id == "main")) + + res <- res[rm_list, ] } - arg_vals %>% dplyr::select(name, call_info, source, component, component_id) + + res[, c("name", "call_info", "source", "component", "component_id")] } mod_type <- function(.mod) class(.mod)[class(.mod) != "model_spec"][1]