-
Notifications
You must be signed in to change notification settings - Fork 36
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
earth (MARS) support #520
Labels
Comments
I went ahead and implemented a poor man's version of ggpredict. Maybe the code will be useful to someone else:
#some dev code for testing that our poor man's ggpredict2 works
#use iris data, fit a linear model, predict the output, plot results
#also then fit the MARS and verify it works too
library(tidyverse)
library(earth)
#> Loading required package: Formula
#> Loading required package: plotmo
#> Loading required package: plotrix
library(ggeffects)
#' Build the grid of values for the first (x-axis) focal term
#'
#' @param x Vector holding the observed values of the focal variable.
#' @param length Number of evenly spaced points to generate when `x` is
#'   numeric.
#' @return For numeric `x`, `length` evenly spaced values running from the
#'   smallest to the largest non-missing observation; otherwise the unique
#'   values of `x`.
make_focal_data_range <- function(x, length = 1000) {
  if (!is.numeric(x)) {
    return(unique(x))
  }
  # Drop NAs when finding the end points: min()/max() would otherwise return
  # NA and seq() would fail on the non-finite bounds.
  bounds <- range(x, na.rm = TRUE)
  seq(bounds[1], bounds[2], length.out = length)
}
#' Build a small set of representative values for a non-first focal term
#'
#' @param x Vector holding the observed values of the focal variable.
#' @param centiles Probabilities at which to take quantiles of a numeric `x`.
#'   The default, `pnorm(seq(-2, 2))`, is approximately
#'   0.0228, 0.1587, 0.5000, 0.8413 and 0.9772.
#' @return For numeric `x`, the (named) quantiles of the non-missing values at
#'   `centiles`; otherwise the unique values of `x`.
make_focal_data_range_ordinal <- function(x, centiles = pnorm(seq(-2, 2))) {
  if (!is.numeric(x)) {
    return(unique(x))
  }
  # quantile() refuses input containing NA unless na.rm = TRUE
  quantile(x, probs = centiles, na.rm = TRUE)
}
#' Pick the single value at which a covariate is held constant
#'
#' @param x Vector holding the observed values of the covariate.
#' @return A length-one value: the mean of the non-missing observations when
#'   `x` is numeric, otherwise the most frequent non-missing value (the first
#'   one in table order on ties), taken from `x` so its type and any factor
#'   levels are preserved.
make_covar_data <- function(x) {
  if (is.numeric(x)) {
    return(mean(x, na.rm = TRUE))
  }

  counts <- table(x)
  mode_label <- names(counts)[which.max(counts)]

  # Locate the first occurrence by position. Logical indexing with
  # `x == mode_label` yields NA entries for missing values, so an NA that
  # precedes the first modal value would be returned instead of the mode.
  first_hit <- match(mode_label, as.character(x))
  x[first_hit[1]]
}
#' Build the prediction grid (`newdata`) for a set of focal and covariate terms
#'
#' Focal terms vary across the grid; covariates are held at one constant value.
#' A focal term may carry explicit values in square brackets, e.g.
#' `"Petal.Length [1.5, 4, 6]"`, in which case exactly those values are used.
#' Otherwise the first focal term gets a fine continuous range and later focal
#' terms get a handful of representative values.
#'
#' @param focal_terms Character vector of focal variable names, each optionally
#'   followed by a `[v1,v2,...]` list of numeric values.
#' @param covar_terms Character vector of covariate names to hold constant.
#' @param data Data frame the model was fitted on.
#' @return A tibble with one column per term and one row per combination of
#'   focal values.
prep_newdata <- function(focal_terms, covar_terms, data) {
  # Bare column names, with any "[...]" value specification stripped
  focal_names <- str_trim(str_remove(focal_terms, "\\s*\\[.*"))

  # Compare on the bare names so "x [1,2]" is still recognised as "x"
  if (any(focal_names %in% covar_terms)) {
    stop("focal terms cannot be in covar terms", call. = FALSE)
  }

  # Arguments for expand_grid(), one entry per term
  call_args <- list()

  for (i in seq_along(focal_terms)) {
    term <- focal_terms[i]
    term_name <- focal_names[i]

    if (str_detect(term, "\\[")) {
      # Explicit values: take the text between the brackets and split on
      # commas, so decimals and negative numbers survive intact
      spec <- str_remove(str_remove(term, "^[^\\[]*\\["), "\\].*$")
      call_args[[term_name]] <- as.numeric(str_trim(str_split(spec, ",")[[1]]))
    } else if (i == 1) {
      # First focal term: fine continuous range
      call_args[[term_name]] <- make_focal_data_range(data[[term_name]])
    } else {
      # Second or later focal term: a few representative values
      call_args[[term_name]] <- make_focal_data_range_ordinal(data[[term_name]])
    }
  }

  # Covariates are fixed at a single value each
  for (term in covar_terms) {
    call_args[[term]] <- make_covar_data(data[[term]])
  }

  # Cross all term values into the full grid
  rlang::exec(expand_grid, !!!call_args)
}
#' Predict from a fitted model on a prediction grid
#'
#' @param model A fitted model object. `lm` and `earth` objects are handled
#'   explicitly; any other class is attempted through the `lm` code path after
#'   a warning.
#' @param newdata Data frame (or tibble) of predictor values.
#' @return A tibble with columns `pred`, `pred_lwr` and `pred_upr`, one row per
#'   row of `newdata`. For `earth` models the two bounds are `NA_real_`, since
#'   no interval is requested from `predict()` for them.
get_model_preds <- function(model, newdata) {
  model_classes <- class(model)
  newdata_df <- as.data.frame(newdata)
  is_lm <- inherits(model, "lm")

  if (!is_lm && inherits(model, "earth")) {
    # Typed NA keeps the bound columns double, matching the other branches
    return(tibble(
      pred = as.vector(predict(model, newdata = newdata_df)),
      pred_lwr = NA_real_,
      pred_upr = NA_real_
    ))
  }

  if (!is_lm) {
    warning(str_glue("model class `{str_c(model_classes, collapse = ', ')}` may not be supported"))
  }

  # lm-style prediction: fit plus lower and upper confidence bounds
  predict(model, newdata = newdata_df, interval = "confidence") %>%
    as_tibble() %>%
    set_names(c("pred", "pred_lwr", "pred_upr"))
}
#' Poor man's ggpredict: prediction grid plus model predictions
#'
#' @param model A fitted model object, passed on to `get_model_preds()`.
#' @param focal_terms Character vector of focal terms, passed on to
#'   `prep_newdata()`.
#' @param covar_terms Character vector of covariates, passed on to
#'   `prep_newdata()`.
#' @param data Data frame used to derive the grid values.
#' @return The prediction grid with the prediction columns bound on column-wise.
ggpredict2 <- function(model, focal_terms, covar_terms, data) {
  grid <- prep_newdata(focal_terms, covar_terms, data)
  bind_cols(grid, get_model_preds(model, grid))
}
# Linear model fit on the iris data, all predictors entered additively
iris_lm <- lm(
  Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species,
  data = iris
)
# MARS fit on the same formula; degree = 2 sets earth's maximum degree of
# interaction to two
iris_mars <- earth(
  Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width + Species,
  data = iris,
  degree = 2
)
summary(iris_mars)
#> Call: earth(formula=Sepal.Length~Sepal.Width+Petal.Length+Petal.Width+...),
#> data=iris, degree=2)
#>
#> coefficients
#> (Intercept) 5.0952768
#> Speciesvirginica -0.4561422
#> h(2.5-Sepal.Width) 0.4683702
#> h(Sepal.Width-2.5) 0.5708274
#> h(3.5-Petal.Length) -0.3059570
#> h(Petal.Length-3.5) 0.8366556
#> h(Petal.Width-2.3) -2.4866291
#>
#> Selected 7 of 20 terms, and 4 of 5 predictors
#> Termination condition: Reached nk 21
#> Importance: Petal.Length, Sepal.Width, Speciesvirginica, Petal.Width, ...
#> Number of terms at each degree of interaction: 1 6 (additive model)
#> GCV 0.1041697 RSS 12.46981 GRSq 0.8490941 RSq 0.8779484
# Does the regular ggpredict() work on an earth model?
ggpredict(iris_mars, terms = "Petal.Length")
#> Error: Models of class `earth` are not yet supported.
# No: the call stops with the error shown above.
# Side by side: lm and MARS predictions over Sepal.Width, with the remaining
# predictors held constant
list(lm = iris_lm, mars = iris_mars) %>%
  imap(function(fit, label) {
    ggpredict2(
      fit,
      focal_terms = "Sepal.Width",
      covar_terms = c("Petal.Width", "Petal.Length", "Species"),
      data = iris
    ) %>%
      mutate(model = label)
  }) %>%
  bind_rows() %>%
  ggplot(aes(x = Sepal.Width, y = pred, color = model)) +
  geom_line() +
  geom_ribbon(aes(ymin = pred_lwr, ymax = pred_upr), alpha = 0.2, linewidth = 0)
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
#with a second focal term
# Two focal terms: Sepal.Width on the x-axis, one line per Petal.Length value,
# one facet per model
list(lm = iris_lm, mars = iris_mars) %>%
  imap(function(fit, label) {
    ggpredict2(
      fit,
      focal_terms = c("Sepal.Width", "Petal.Length"),
      covar_terms = c("Petal.Width", "Species"),
      data = iris
    ) %>%
      mutate(model = label)
  }) %>%
  bind_rows() %>%
  ggplot(aes(x = Sepal.Width, y = pred, color = factor(round(Petal.Length, 2)))) +
  geom_line() +
  geom_ribbon(aes(ymin = pred_lwr, ymax = pred_upr), alpha = 0.2, linewidth = 0) +
  facet_wrap("model")
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf
Created on 2024-05-16 with reprex v2.1.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
In theory, this should work with some minor modifications. One issue is that earth models cannot return standard errors or confidence intervals due to inherent theoretical limitations. However, one should still be able to plot the model predictions using `ggpredict`. One can do it manually.
The text was updated successfully, but these errors were encountered: