glmnet penalty parameter #195

Closed · topepo opened this issue Jul 20, 2019 · 9 comments

@topepo (Member) commented Jul 20, 2019

Right now, if a modeling function using glmnet gets a specific penalty value, the resulting model cannot make predictions at any other penalty value. If no value is given, the model can predict at any penalty, but there is no default value for predict() to use.

I propose doing what caret does:

  • if a penalty value is given, save it but do not pass that value to the lambda argument of glmnet.
  • attach the supplied penalty value to the glmnet fit object.
  • when using predict(), use the value attached to the glmnet object.

It is suboptimal to modify the underlying object, but it would give us the best of both worlds: predict() works as expected (and without error), and multi_predict() can also be used.
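A minimal sketch of that flow, written against glmnet directly. The helper names (fit_glmnet_full_path(), predict_at_penalty()) and the "user_penalty" attribute are illustrative stand-ins, not the actual parsnip internals:

library(glmnet)

fit_glmnet_full_path <- function(x, y, penalty) {
  # Fit the whole regularization path: lambda is deliberately not passed.
  fit <- glmnet::glmnet(x, y, family = "gaussian")
  # Stash the user-supplied penalty on the fit object for later use.
  attr(fit, "user_penalty") <- penalty
  fit
}

predict_at_penalty <- function(fit, newx, penalty = NULL) {
  # Default to the penalty supplied at fit time.
  if (is.null(penalty)) {
    penalty <- attr(fit, "user_penalty")
  }
  predict(fit, newx = newx, s = penalty)
}

x <- as.matrix(mtcars[, -1])
y <- mtcars$mpg

fit <- fit_glmnet_full_path(x, y, penalty = 0.1)
predict_at_penalty(fit, x[1:2, ])        # uses the stored penalty = 0.1
predict_at_penalty(fit, x[1:2, ], 0.01)  # any other value still works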

@alexpghayes

I'm opposed for the following reasons:

  1. The predictions may differ. If you don't pass lambda, an entire path of values gets fit. You can then get predictions "at any lambda", but really it is an interpolation of the predictions from the two nearest values of lambda on the path (see the sketch at the end of this comment).

  2. It's computationally more expensive to fit the whole path than to fit a single point on the path.

Neither of these is a major dealbreaker, but they are hard to explain, and they turn the parsnip behavior into an abstraction that is hard to understand. If you pass a specific value of lambda, it should really mean that you want that value of lambda.

If the goal is to optimize computation time, that should happen in the tuning package, which should note that multiple lambda values have been requested and then form a lambda path from those requested values.
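A rough illustration of point 1, fitting the same data twice; the exact numbers will differ only slightly here, and this is a sketch against glmnet itself, not the parsnip code path:

library(glmnet)

x <- as.matrix(mtcars[, -1])
y <- mtcars$mpg

# Fit the whole path (no lambda supplied); predictions "at" lambda = 0.1
# are interpolated from the two nearest lambda values on that path.
path_fit <- glmnet::glmnet(x, y)
p_interp <- predict(path_fit, newx = x[1:2, ], s = 0.1)

# A model fit at exactly lambda = 0.1 can give slightly different numbers.
single_fit <- glmnet::glmnet(x, y, lambda = 0.1)
p_single <- predict(single_fit, newx = x[1:2, ], s = 0.1)

cbind(interpolated = as.vector(p_interp), single_fit = as.vector(p_single))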

@topepo (Member, Author) commented Jul 20, 2019

For issue 1, I don't see it as a problem. Nobody has ever complained about (or noticed) this in caret. Having looked at enough path plots, the path appears to be piecewise linear anyway.

For issue 2, according to ?glmnet:

glmnet relies on its warm starts for speed, and it's often faster to fit a whole path than compute a single fit.

We could do as you suggest in the tuning package, but then we would have to write a whole set of conditional logic just for this model. I think the approach above is the lesser of two evils.
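The warm-start claim is easy to check on simulated data; relative timings will vary by problem size and machine, so treat this only as a sanity check rather than part of the proposal:

library(glmnet)

set.seed(1)
x <- matrix(rnorm(5000 * 200), nrow = 5000)
y <- rnorm(5000)

# Whole path (defaults to ~100 lambda values, fit with warm starts)
system.time(glmnet::glmnet(x, y))

# Single lambda value
system.time(glmnet::glmnet(x, y, lambda = 0.1))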

@topepo (Member, Author) commented Jul 20, 2019

This commit is a prototype for the changes. In this case, we avoid adding anything to the glmnet object. This approach feels pretty clean but will require some documentation for the user.

In translate.linear_reg(), we:

  • remove the lambda argument from the fit call that will be evaluated.

  • evaluate the penalty argument in the specification to make it an actual value instead of a quosure.

In predict._elnet(): if no value is manually specified for penalty, we use the one originally declared by the user in the specification.

An example:

library(parsnip)
options(width = 100)

mod <- linear_reg(penalty = .1) %>% set_engine("glmnet")

# What parsnip _wants_ to do
parsnip:::translate.default(mod)
#> Linear Regression Model Specification (regression)
#> 
#> Main Arguments:
#>   penalty = 0.1
#> 
#> Computational engine: glmnet 
#> 
#> Model fit template:
#> glmnet::glmnet(x = missing_arg(), y = missing_arg(), weights = missing_arg(), 
#>     lambda = 0.1, family = "gaussian")

# after we modify the call for glmnet models
translate(mod)
#> Linear Regression Model Specification (regression)
#> 
#> Main Arguments:
#>   penalty = 0.1
#> 
#> Computational engine: glmnet 
#> 
#> Model fit template:
#> glmnet::glmnet(x = missing_arg(), y = missing_arg(), weights = missing_arg(), 
#>     family = "gaussian")

# note the lack of a lambda argument

mod_fit <- fit(mod, mpg ~ ., data = mtcars)
# Fits the whole path
str(mod_fit$fit$lambda)
#>  num [1:79] 5.15 4.69 4.27 3.89 3.55 ...

# Predictions made for what we asked for (penalty = .1)
predict(mod_fit, mtcars[1:2, -1])
#> # A tibble: 2 x 1
#>   .pred
#>   <dbl>
#> 1  22.5
#> 2  22.1

# or other values
predict(mod_fit, mtcars[1:2, -1], penalty = 0)
#> # A tibble: 2 x 1
#>   .pred
#>   <dbl>
#> 1  22.6
#> 2  22.1

# or multiple values
multi_predict(mod_fit, mtcars[1:2, -1], penalty = 0:1) %>% 
  tidyr::unnest()
#> # A tibble: 4 x 2
#>   penalty .pred
#>     <int> <dbl>
#> 1       0  22.6
#> 2       1  22.2
#> 3       0  22.1
#> 4       1  21.5

Created on 2019-07-20 by the reprex package (v0.2.1)

Session info (abridged): R 3.6.0 on macOS High Sierra 10.13.6; parsnip 0.0.2.9000 (local), glmnet 2.0-16, rlang 0.4.0.9000, tidyr 0.8.3.


@topepo (Member, Author) commented Jul 20, 2019

The same strategy is also used for multi_predict().

topepo added a commit that referenced this issue Jul 21, 2019
@fxdlmatt commented Aug 5, 2019

Is there any plan to extend this to other families? So far the work seems specific to elnet objects.

Also, is there a nice way to pass an offset?

@topepo (Member, Author) commented Aug 5, 2019

Yes, it is implemented for all glmnet models in parsnip.

Also, is there a nice way to pass an offset?

Not yet. We need to get to that and case weights.
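Until parsnip supports offsets, one workaround is to call glmnet directly, since glmnet itself takes an offset at fit time and a matching newoffset at prediction time. A minimal sketch (the offset column here is arbitrary and only for illustration):

library(glmnet)

x   <- as.matrix(mtcars[, c("cyl", "disp", "wt")])
y   <- mtcars$mpg
off <- log(mtcars$hp)  # arbitrary offset, purely for illustration

# glmnet accepts an offset at fit time ...
fit <- glmnet::glmnet(x, y, offset = off)

# ... and requires a matching newoffset when predicting.
predict(fit, newx = x[1:2, ], newoffset = off[1:2], s = 0.1)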

@dhalpern commented Jul 9, 2020

Has there been any update on passing an offset by any chance?

@Steviey commented Sep 20, 2020

Should I use lambda or not?

@topepo (Member, Author) commented Sep 20, 2020

@dhalpern Not yet

@Steviey Can you be more specific with your question?

tidymodels locked as resolved and limited conversation to collaborators on Oct 23, 2020.