Quantile predictions for linear regression models via lm() #1132
Thanks! I think the output structure looks good. And this generalizes the … . In our package, we currently have two different implementations, which illustrate two different (poor) workaround strategies for parsnip.
I suspect that the "right" way to do this is to add something like "quantile regression" as a possible mode. Some engines would then be able to implement it, and there are R packages that do the right thing for linear models, random forests, and gradient boosting, among other models. A potential concern here is that, as you did above, one needs to specify the required quantiles, at least by fit time. Another option, which hadn't occurred to me until now, is to have the required quantiles be an argument to …
# remotes::install_github("cmu-delphi/epipredict@grf-qr-engine")
# install.packages(c("grf", "quantreg"))
library(epipredict)
library(grf)
library(quantreg)
# 1. Creating an engine that does quantile regression (but it's a linear model)
tib <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100))
rq_spec <- quantile_reg(quantile_levels = c(.3, .5, .8))
ff <- rq_spec %>% fit(y ~ ., data = tib)
predict(ff, new_data = tib[1:5, ]) %>%
# makes a list-col in `.pred`, but using vctrs::new_rcrd()
pivot_quantiles_wider(.pred) # just to see the results
#> # A tibble: 5 × 3
#> `0.3` `0.5` `0.8`
#> <dbl> <dbl> <dbl>
#> 1 -0.140 0.392 0.682
#> 2 -0.266 0.228 0.726
#> 3 -0.306 0.272 0.704
#> 4 -0.471 0.0172 0.777
#> 5 -0.311 0.216 0.724
# 2. Adding a (quantile) engine, but this engine could also do
# classification/regression (neither is implemented)
rf_spec <- rand_forest(mode = "regression") %>%
set_engine(engine = "grf_quantiles", quantiles = c(.1, .5, .7)) # stuck with
# arg names provided by the package
out <- fit(rf_spec, formula = y ~ ., data = tib)
predict(out, new_data = tib[1:5, ]) %>%
pivot_quantiles_wider(.pred)
#> # A tibble: 5 × 3
#> `0.1` `0.5` `0.7`
#> <dbl> <dbl> <dbl>
#> 1 -1.08 0.245 0.556
#> 2 -1.22 0.219 0.434
#> 3 -1.08 0.224 0.401
#> 4 -1.08 0.219 0.538
#> 5 -1.08 0.219 0.444
A quick comment is that I don't think this is an "either/or" situation. Being able to back out quantiles from the prediction interval from linear regression, and being able to fit a quantile regression, are two different things, each of which would be nice to have. So perhaps this PR is just about the first (and generalizations thereof), but it would still be useful to accommodate the second (and generalizations thereof) in a separate PR.
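For the first of these, here is a minimal base-R sketch, assuming a Gaussian linear model fit with lm(). Each quantile of the predictive distribution is the point prediction plus a t quantile times the predictive standard error, which is the same quantity predict.lm() uses for its prediction intervals. The helper lm_quantiles() is hypothetical: it is not part of parsnip and not necessarily how this PR implements it.

```r
# Hypothetical helper (not parsnip's API): quantiles of the t-based
# predictive distribution of a Gaussian linear model fit with lm().
lm_quantiles <- function(fit, new_data, quantile_levels = c(0.1, 0.5, 0.9)) {
  pred <- predict(fit, newdata = new_data, se.fit = TRUE)
  # Predictive standard error: uncertainty in the fitted mean plus residual noise
  se_pred <- sqrt(pred$se.fit^2 + pred$residual.scale^2)
  # Rows are new observations, columns are quantile levels
  out <- pred$fit + outer(se_pred, qt(quantile_levels, df = pred$df))
  colnames(out) <- as.character(quantile_levels)
  out
}

fit <- lm(mpg ~ disp + wt, data = mtcars)
lm_q <- lm_quantiles(fit, mtcars[1:5, ], quantile_levels = c(0.025, 0.5, 0.975))
lm_q
```

The 0.025 and 0.975 columns should match the lwr and upr columns of predict(fit, mtcars[1:5, ], interval = "prediction", level = 0.95). This gives a quick consistency check between the interval view and the quantile view.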
Just piping in here since this feature is an interest of mine...
Another complication to keep in mind for the quantile regression "mode" option is that some engines may be restricted to estimating only one quantile at a time. I know LightGBM has this restriction; I don't know whether there are others.
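For an engine limited to one quantile per fit, one workaround is to fit one model per quantile level and bind the predictions into one column per level. The sketch below assumes a linear model and minimizes the pinball (check) loss with stats::optim(). The helpers pinball_loss() and fit_one_quantile() are hypothetical stand-ins for a single-quantile engine. They are not LightGBM's API.

```r
# Pinball (check) loss for quantile level tau
pinball_loss <- function(resid, tau) {
  sum(resid * (tau - (resid < 0)))
}

# Hypothetical single-quantile "engine": a linear model fit by minimizing
# the pinball loss, started from the least-squares solution
fit_one_quantile <- function(X, y, tau) {
  start <- qr.solve(X, y)
  objective <- function(beta) pinball_loss(drop(y - X %*% beta), tau)
  optim(start, objective, method = "Nelder-Mead",
        control = list(maxit = 5000))$par
}

set.seed(42)
n <- 200
x <- rnorm(n)
y <- 1 + 2 * x + rnorm(n)
X <- cbind(intercept = 1, x = x)

quantile_levels <- c(0.1, 0.5, 0.9)
# One separate fit per quantile level; each column holds one model's coefficients
coefs <- sapply(quantile_levels, function(tau) fit_one_quantile(X, y, tau))
colnames(coefs) <- as.character(quantile_levels)

# Combine: rows are new observations, columns are quantile levels
new_X <- cbind(intercept = 1, x = c(-1, 0, 1))
preds <- new_X %*% coefs
preds
```

Each level is fit separately, so nothing prevents the quantile curves from crossing. Monotone predictions would require a joint fit or a post-hoc sort of each row.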
@ryantibs Agreed, absolutely. Extracting quantiles from the prediction intervals may well be useful if we trust the model. @joranE's point is important here. I've only used engines that can fit multiple quantiles and return a single (non-parsnip) model object, but not every engine can. Multiple responses in one model object vs. one model object per response is perhaps a separate issue, because it also arises outside the quantile context.
library(tidymodels)
form <- formula(cbind(mpg, disp) ~ .)
lm_fit <- linear_reg() %>%
fit(form, data = mtcars[, 1:6])
predict(lm_fit, new_data = mtcars[1:5, 1:6])
#> # A tibble: 5 × 2
#> .pred_mpg .pred_disp
#> <dbl> <dbl>
#> 1 23.0 179.
#> 2 22.3 193.
#> 3 25.8 102.
#> 4 20.6 222.
#> 5 17.1 306.
# fails
glm_fit <- linear_reg(engine = "glm") %>%
fit(form, data = mtcars[, 1:6])
#> Error in x[good, , drop = FALSE]: (subscript) logical subscript too long
# this is technically allowed by glmnet
glmnet_err <- linear_reg(penalty = 1) %>%
set_engine("glmnet", family = "mgaussian") %>%
fit(form, data = mtcars[, 1:6]) # works
predict(glmnet_err, new_data = mtcars[1:5, 1:6]) # fails?
#> Error in cbind2(1, newx) %*% (nbeta[[i]]): non-conformable arguments
x <- model.matrix(form, data = mtcars[, 1:6])
predict(extract_fit_engine(glmnet_err), newx = x[1:5, -1], s = 1)
#> , , 1
#>
#> mpg disp
#> Mazda RX4 22.85137 180.0772
#> Mazda RX4 Wag 22.09631 194.4642
#> Datsun 710 25.81773 103.1353
#> Hornet 4 Drive 20.62226 221.3567
#> Hornet Sportabout 17.03880 305.4019
# fails
forest_fit <- rand_forest(mode = "regression") %>%
fit(form, data = mtcars[, 1:6])
#> Error in ranger::ranger(x = maybe_data_frame(x), y = y, num.threads = 1, : Error: Competing risks not supported yet. Use status=1 for events and status=0 for censoring.
Created on 2024-07-24 with reprex v2.1.1
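The following base-R sketch illustrates the distinction between one model object and one object per response outside parsnip. When lm() receives a cbind() response, it returns a single "mlm" object whose coefficients equal those from fitting each response separately. The sketch uses the same mtcars columns as the reprex above. The variable names are illustrative only.

```r
dat <- mtcars[, 1:6]   # mpg, cyl, disp, hp, drat, wt
predictors <- c("cyl", "hp", "drat", "wt")

# One model object holding both responses (class "mlm")
joint <- lm(cbind(mpg, disp) ~ cyl + hp + drat + wt, data = dat)

# One model object per response
separate <- lapply(c(mpg = "mpg", disp = "disp"), function(resp) {
  lm(reformulate(predictors, response = resp), data = dat)
})

coef(joint)          # one column of coefficients per response
coef(separate$mpg)   # same values as coef(joint)[, "mpg"]
```

The point is structural. lm() can store several responses in one object, but the glm and ranger fits in the reprex above fail on the same formula. A wrapper for those engines would have to fall back to one object per response.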
Draft for review - unit tests to come.
@ryantibs @dajmcdon
Enable quantile predictions for linear regression models via lm(). This is the first PR to populate type = "quantile" predictions for existing models. The next few PRs will cover other existing models with existing engines (stan, BART, ranger, and (maybe) lightgbm). After that, we can implement a few more engines (e.g., quantreg).
Why not a different mode?
First, we can possibly generate quantile intervals for classification models (e.g., quantiles of the event probability distributions). Stan logistic models are a good example.
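This sketch shows what quantiles of the event probability distribution could look like. It assumes a matrix of posterior draws of the linear predictor, with one column per new observation. Here the draws are simulated with rnorm() and do not come from a real stan fit. Each draw is mapped to a probability with plogis(), and quantiles are taken per observation.

```r
set.seed(1)
n_draws <- 4000
# Simulated stand-in for posterior draws of the linear predictor for three
# new observations (columns); a real fit would supply these draws instead
eta_draws <- cbind(
  rnorm(n_draws, mean = -1, sd = 0.3),
  rnorm(n_draws, mean = 0, sd = 0.5),
  rnorm(n_draws, mean = 2, sd = 0.4)
)
prob_draws <- plogis(eta_draws)   # draws of the event probability

quantile_levels <- c(0.05, 0.5, 0.95)
# Rows are observations, columns are quantile levels
prob_quantiles <- t(apply(prob_draws, 2, quantile, probs = quantile_levels))
prob_quantiles
```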
Second, the main difference for quantile predictions is the type option to predict.model_fit(). AFAIK we don't need specialized metrics or other complications (as we had for survival analysis), but someone speak up if I'm being short-sighted.
Example:
Created on 2024-07-18 with reprex v2.1.0