Description
The problem
I'm having trouble with predicting with a custom model created here tidymodels/tune#203, spesifically. I would like to use custom optimization techniques therefore I needed a custom tidymodel GPR. However, it is seen that the predict function cannot be achieved due to dplyr's filter. I am new to R but I searched on the internet for similar results and tried to comprehend where things could go wrong. I did some combinations in setting pred function with kernlab gpr to see if this is a mathematical issue rather, but the result did not change if not there was a incompatibility error.
I also proceeded with the same data in tidymodels/tune#203 but no luck.
Please notice I had to comment model translate code block and discard workflow as there were issues described.
Please tolerate any misformat or fault caused by me as I am only a newcomer to the R community.
Best,
Reproducible example
library(kernlab)
library(tidymodels)
library(conflicted)
library(dplyr)
library(tidymodels)
library(pso)
library(GA)
library(scales)
conflicts_prefer(kernlab::alpha)
tidymodels_prefer()
conflict_prefer("filter", "dplyr")
set_new_model("gauss_rbf")
set_model_mode(model = "gauss_rbf", mode = "regression")
set_model_engine("gauss_rbf",
mode = "regression",
eng = "kernlab")
set_dependency("gauss_rbf", eng = "kernlab", pkg = "kernlab")
set_model_arg(model = "gauss_rbf",
eng = "kernlab",
parsnip = "rbf_sigma", ## rbf_sigma exists for svm in dials
original = "rbf_sigma",
func = list(pkg = "dials", fun = "rbf_sigma"),
has_submodel = FALSE)
gauss_rbf <- function(mode = "regression", rbf_sigma = NULL) {
# Check for correct mode
if (mode != "regression") {
stop("`mode` should be 'regression'", call. = FALSE)
}
# Capture the arguments in quosures
args <- list(rbf_sigma = rlang::enquo(rbf_sigma))
# Save some empty slots for future parts of the specification
out <- list(args = args,
eng_args = NULL,
mode = mode,
method = NULL,
engine = NULL)
# set classes in the correct order
class(out) <- make_classes("gauss_rbf")
out
}
gauss_rbf_wrap <- function(formula, data, rbf_sigma) {
model <- kernlab::gausspr(x = formula,
data = data,
kpar = list(sigma = rbf_sigma),
type = "regression",
kernel = "rbfdot")
return(model)
}
set_fit(model = "gauss_rbf",
eng = "kernlab",
mode = "regression",
value = list(
interface = "formula",
protect = c("formula", "data"),
func = c(fun = "gauss_rbf_wrap"),
defaults = list()))
set_pred(model = "gauss_rbf",
eng = "kernlab",
mode = "regression",
type = "numeric",
value = list(
pre = NULL,
post = function(results, object) as.vector(results) ,
func = c(pkg = "kernlab", fun = "predict"), # need to register
# the package, otherwise predict fails
args = list(
object = expr(object$fit),
newdata = expr(new_data),
type = "response")))
show_model_info("gauss_rbf")
#> Information for `gauss_rbf`
# modes: unknown, regression
#
# engines:
# regression: kernlab
#
# ¹The model can use case weights.
#
# arguments:
# kernlab:
# rbf_sigma --> rbf_sigma
#
# fit modules:
# engine mode
# kernlab regression
#
# prediction modules:
# mode engine methods
# regression kernlab numeric
gauss_model <- gauss_rbf(mode = "regression", rbf_sigma = 0.1) %>%
set_engine("kernlab")
# ERROR :
# Model fit template:
# Error in `call2()`:
# ! `ns` must be a string
# Run `rlang::last_error()` to see where the error occurred.
#
# gauss_model %>%
# translate()
data("attrition")
attrition <- attrition %>%
mutate(JobInvolvement = as.factor(JobInvolvement),
Education = as.factor(Education),
EnvironmentSatisfaction = as.factor(EnvironmentSatisfaction),
PerformanceRating = as.factor(PerformanceRating),
RelationshipSatisfaction = as.factor(RelationshipSatisfaction),
WorkLifeBalance = as.factor(WorkLifeBalance),
Attrition = factor(Attrition, levels = c("Yes", "No")))
set.seed(123)
intrain <- initial_split(attrition, prop = 0.8, strata = "Attrition")
rec <- recipe(Attrition ~ ., data = training(intrain)) %>%
themis::step_upsample(Attrition) %>%
step_scale(all_numeric()) %>%
step_nzv(all_numeric()) %>%
step_pca(all_numeric(), threshold = 0.9) %>%
step_rm(OverTime) %>%
prep()
training <- juice(rec)
testing <- bake(rec, testing(intrain))
training <- training %>%
mutate( Attrition = sapply(factor(Attrition, levels = c("Yes", "No")), unclass))
testing <- testing %>%
mutate( Attrition = sapply(factor(Attrition, levels = c("Yes", "No")), unclass))
gauss_fit <- gauss_model %>%
fit(formula = Attrition ~.,
data = training)
# gauss_fit
#
# parsnip model object
#
# Gaussian Processes object of class "gausspr"
# Problem type: regression
#
# Gaussian Radial Basis kernel function.
# Hyperparameter : sigma = 0.1
#
# Number of training instances learned : 1972
# Train error : 0.151315008
#ERROR:
# Error in if (nrow(out) != 1L) { : argument is of length zero
# workflow() %>%
# add_model(gauss_model) %>%
# add_formula(Attrition ~ .) %>%
# fit(data = training)
fit(gauss_model, Attrition ~ ., data = training)
#
# parsnip model object
#
# Gaussian Processes object of class "gausspr"
# Problem type: regression
#
# Gaussian Radial Basis kernel function.
# Hyperparameter : sigma = 0.1
#
# Number of training instances learned : 1972
# Train error : 0.151315008
predict(gauss_fit,
testing)
The output is:
Error in UseMethod("filter") :
no applicable method for 'filter' applied to an object of class "NULL"
The traceback:
> traceback()
9: dplyr::filter(., mode == object$spec$mode, engine == object$spec$engine)
8: dplyr::pull(., remove_intercept)
7: get_encoding(class(object$spec)[1]) %>% dplyr::filter(mode ==
object$spec$mode, engine == object$spec$engine) %>% dplyr::pull(remove_intercept)
6: prepare_data(object, new_data)
5: predict_numeric.model_fit(object = object, new_data = new_data,
...)
4: predict_numeric(object = object, new_data = new_data, ...)
3: predict.model_fit(gauss_fit, testing)
2: predict(gauss_fit, testing)
1: predict(gauss_fit, testing)
Created on 2023-03-08 with reprex v2.0.2
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.2.2 (2022-10-31 ucrt)
#> os Windows 10 x64 (build 19044)
#> system x86_64, mingw32
#> ui RTerm
#> language (EN)
#> collate English_United States.utf8
#> ctype English_United States.utf8
#> tz Europe/Paris
#> date 2023-03-08
#> pandoc 2.19.2 @ C:/Program Files/RStudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> cli 3.6.0 2023-01-09 [1] CRAN (R 4.2.2)
#> clipr 0.8.0 2022-02-22 [1] CRAN (R 4.2.2)
#> digest 0.6.31 2022-12-11 [1] CRAN (R 4.2.2)
#> evaluate 0.20 2023-01-17 [1] CRAN (R 4.2.2)
#> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.2.2)
#> fs 1.6.1 2023-02-06 [1] CRAN (R 4.2.2)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.2.2)
#> htmltools 0.5.4 2022-12-07 [1] CRAN (R 4.2.2)
#> knitr 1.42 2023-01-25 [1] CRAN (R 4.2.2)
#> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.2.2)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.2.2)
#> rlang 1.0.6 2022-09-24 [1] CRAN (R 4.2.2)
#> rmarkdown 2.20 2023-01-19 [1] CRAN (R 4.2.2)
#> rstudioapi 0.14 2022-08-22 [1] CRAN (R 4.2.2)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.2.2)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.2.2)
#> xfun 0.37 2023-01-31 [1] CRAN (R 4.2.2)
#> yaml 2.3.7 2023-01-23 [1] CRAN (R 4.2.2)
#>
#> [1] C:/Users/yusuf/AppData/Local/R/win-library/4.2
#> [2] C:/Program Files/R/R-4.2.2/library
#>
#> ──────────────────────────────────────────────────────────────────────────────