diff --git a/R/fit.R b/R/fit.R index 55badc098..ca74aa9ec 100644 --- a/R/fit.R +++ b/R/fit.R @@ -265,6 +265,7 @@ fit_xy.model_spec <- rlang::warn(glue::glue("Engine set to `{object$engine}`.")) } } + y_var <- colnames(y) if (object$engine != "spark" & NCOL(y) == 1 & !(is.vector(y) | is.factor(y))) { if (is.matrix(y)) { @@ -278,6 +279,7 @@ fit_xy.model_spec <- eval_env <- rlang::env() eval_env$x <- x eval_env$y <- y + eval_env$y_var <- y_var eval_env$weights <- weights_to_numeric(case_weights, object) # TODO case weights: pass in eval_env not individual elements diff --git a/R/fit_helpers.R b/R/fit_helpers.R index d08aead4e..5fd9021df 100644 --- a/R/fit_helpers.R +++ b/R/fit_helpers.R @@ -187,11 +187,16 @@ xy_form <- function(object, env, control, ...) { control = control, ... ) - if (is.vector(env$y)) { - data_obj$y_var <- character(0) + if (!is.null(env$y_var)) { + data_obj$y_var <- env$y_var } else { + if (is.vector(env$y)) { + data_obj$y_var <- character(0) + } + data_obj$y_var <- colnames(env$y) } + res$preproc <- data_obj[c("x_var", "y_var")] res }