Reporting uncertainty #37

Closed
alexpghayes opened this issue Aug 3, 2018 · 5 comments
@alexpghayes

I'm reworking lots of broom::augment() methods at the moment and am discovering that packages do some crazy stuff to report uncertainty. Defining some standards for reporting uncertainty early on seems like a good idea.

For classification problems, reporting the class probabilities makes sense, but can become problematic for outcomes with high cardinalities. Nobody wants 1000 columns of class probabilities. One option is to just report the most likely class along with its probability, or the top k = 5 or so classes by default.
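A minimal base-R sketch of the top-k idea, assuming the class probabilities arrive as a numeric matrix with one named column per class. top_k_classes() is a hypothetical helper (not a broom or parsnip function), and its long output layout (.row, .rank, .pred_class, .pred_prob) is an assumption; k = 1 reduces to "most likely class plus its probability".

```r
# Hypothetical helper, not part of broom or parsnip: keep only the k most
# likely classes per observation instead of one probability column per class.
top_k_classes <- function(prob, k = 5) {
  # prob: numeric matrix, one row per observation, one named column per class
  stopifnot(is.matrix(prob), !is.null(colnames(prob)))
  n <- nrow(prob)
  k <- min(k, ncol(prob))
  # n x k matrix of column indices, sorted from most to least likely
  idx <- matrix(
    unlist(lapply(seq_len(n), function(i) {
      order(prob[i, ], decreasing = TRUE)[seq_len(k)]
    })),
    nrow = n, byrow = TRUE
  )
  cls <- matrix(colnames(prob)[idx], nrow = n)
  p <- matrix(prob[cbind(rep(seq_len(n), times = k), as.vector(idx))], nrow = n)
  # Long format: one row per (observation, rank) pair
  out <- data.frame(
    .row = rep(seq_len(n), times = k),
    .rank = rep(seq_len(k), each = n),
    .pred_class = as.vector(cls),
    .pred_prob = as.vector(p),
    stringsAsFactors = FALSE
  )
  out <- out[order(out$.row, out$.rank), ]
  rownames(out) <- NULL
  out
}

prob <- matrix(
  c(0.1, 0.7, 0.2,
    0.5, 0.3, 0.2),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("a", "b", "c"))
)
# Row 1 keeps b (0.7) and c (0.2); row 2 keeps a (0.5) and b (0.3)
top_k_classes(prob, k = 2)
```

The output grows with n * k rather than with the number of classes, which is the point when there are 1000 levels.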

For regression problems I think there's more nuance. Open questions:

  • Is the best way to report uncertainty in a regression outcome to add a column of standard errors se_fit or similar?
  • How should users specify that they want confidence intervals vs prediction intervals?
  • Should confidence intervals or prediction intervals be the default reporting option?
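For a linear model the two interval types differ only in the variance they use, which may help frame the "confidence vs prediction" question. A base-R sketch on mtcars; the output column names (.std_error, .conf_lower, .pred_lower, ...) are hypothetical. The confidence interval is fit ± t·se_fit, while the prediction interval also adds the residual variance, fit ± t·sqrt(se_fit² + σ̂²), so it is always wider.

```r
# Same OLS model as used later in this thread
fit <- lm(hp ~ mpg, data = mtcars)
new <- data.frame(mpg = c(15, 25))

p <- predict(fit, newdata = new, se.fit = TRUE)
t_crit <- qt(0.975, df = p$df)  # two-sided 95% critical value

# Confidence interval: uncertainty in the fitted mean only
ci_half <- t_crit * p$se.fit
# Prediction interval: also includes residual noise of a new observation
pi_half <- t_crit * sqrt(p$se.fit^2 + p$residual.scale^2)

out <- data.frame(
  .pred       = unname(p$fit),
  .std_error  = unname(p$se.fit),
  .conf_lower = unname(p$fit - ci_half),
  .conf_upper = unname(p$fit + ci_half),
  .pred_lower = unname(p$fit - pi_half),
  .pred_upper = unname(p$fit + pi_half)
)
out

# Cross-check against stats::predict.lm's own interval computation
ci <- predict(fit, newdata = new, interval = "confidence")
pred_int <- predict(fit, newdata = new, interval = "prediction")
stopifnot(
  isTRUE(all.equal(out$.conf_lower, unname(ci[, "lwr"]))),
  isTRUE(all.equal(out$.pred_upper, unname(pred_int[, "upr"])))
)
```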
@topepo
Member

topepo commented Aug 3, 2018

> For classification problems, reporting the class probabilities makes sense, but can become problematic for outcomes with high cardinalities. Nobody wants 1000 columns of class probabilities. One option is to just report the most likely class along with its probability, or the top k = 5 or so classes by default.

Right now, there are functions to make predictions in different contexts (e.g. classes, probabilities, etc.) and a predict method binds them all via a type argument. Another specific prediction function could be written for this purpose and blended into that method.
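One way to picture "small prediction functions plus one type argument", sketched with a base-R logistic regression. pred_prob(), pred_class() and predict_by_type() are hypothetical names chosen for illustration, not parsnip's actual internals; a top-k function would just be one more switch() branch.

```r
# Hypothetical sketch: single-purpose prediction functions, plus one entry
# point that chooses among them via a `type` argument.
cars2 <- transform(mtcars, am = factor(am, levels = c(0, 1),
                                       labels = c("auto", "manual")))
fit <- glm(am ~ mpg, data = cars2, family = binomial)

# One probability column per class level, named .pred_<level>
pred_prob <- function(object, new_data) {
  lev <- levels(object$model[[1]])  # response is the first model-frame column
  # For a factor response, glm models P(second level)
  p <- unname(predict(object, newdata = new_data, type = "response"))
  out <- data.frame(1 - p, p)
  names(out) <- paste0(".pred_", lev)
  out
}

# Hard class prediction derived from the probabilities
pred_class <- function(object, new_data) {
  probs <- pred_prob(object, new_data)
  lev <- sub("^\\.pred_", "", names(probs))
  best <- max.col(as.matrix(probs), ties.method = "first")
  data.frame(.pred_class = factor(lev[best], levels = lev))
}

predict_by_type <- function(object, new_data, type = c("class", "prob")) {
  type <- match.arg(type)
  switch(type,
    class = pred_class(object, new_data),
    prob  = pred_prob(object, new_data)
  )
}

new <- data.frame(mpg = c(15, 30))
res <- cbind(predict_by_type(fit, new, "class"),
             predict_by_type(fit, new, "prob"))
res
```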

>   • Is the best way to report uncertainty in a regression outcome to add a column of standard errors se_fit or similar?
>   • How should users specify that they want confidence intervals vs prediction intervals?

Those are specific to only some models. My thinking so far is to have a predict_raw function that can be used to access the model's own prediction function and allow ... to be passed through to it. That can then generate whatever columns are wanted.
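A sketch of what a predict_raw() pass-through could look like, assuming it simply forwards ... to the model's own predict() method (the name comes from the comment above; the body is hypothetical). The caller still has to know that lm spells the option se.fit while loess spells it se, and gets back whatever shape the engine returns.

```r
# Hypothetical predict_raw(): forward `...` untouched to the model's own
# predict() method so engine-specific arguments keep working.
predict_raw <- function(object, new_data, ...) {
  predict(object, newdata = new_data, ...)
}

fit_lm <- lm(hp ~ mpg, data = mtcars)
fit_lo <- loess(hp ~ mpg, data = mtcars)
new <- data.frame(mpg = c(15, 25))

raw_lm <- predict_raw(fit_lm, new, se.fit = TRUE)  # lm calls it se.fit
raw_lo <- predict_raw(fit_lo, new, se = TRUE)      # loess calls it se
str(raw_lm)
str(raw_lo)
```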

It would be good to have some harmonization here too (maybe predict_ci and predict_pi functions), which would require a little more (easy) work in parsnip's model definitions.
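Harmonized predict_ci()/predict_pi() wrappers might look like this for lm(), assuming both return the same .pred/.pred_lower/.pred_upper columns so that only the function name selects the interval type. These are hypothetical sketches built on the interval argument of stats::predict.lm; other engines would need their own interval code.

```r
# Hypothetical harmonized wrappers: same arguments, same output columns
interval_to_df <- function(m) {
  data.frame(
    .pred       = unname(m[, "fit"]),
    .pred_lower = unname(m[, "lwr"]),
    .pred_upper = unname(m[, "upr"])
  )
}

predict_ci <- function(object, new_data, level = 0.95) {
  interval_to_df(predict(object, newdata = new_data,
                         interval = "confidence", level = level))
}

predict_pi <- function(object, new_data, level = 0.95) {
  interval_to_df(predict(object, newdata = new_data,
                         interval = "prediction", level = level))
}

fit <- lm(hp ~ mpg, data = mtcars)
new <- data.frame(mpg = c(15, 25))
ci <- predict_ci(fit, new)
pi_df <- predict_pi(fit, new)
ci
pi_df
```

Because the columns are identical, downstream code (plotting, binding to new_data) does not need to care which interval was requested.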

>   • Should confidence intervals or prediction intervals be the default reporting option?

I don't think so, since only a minority of models can produce them.

@alexpghayes
Author

Mostly a note to myself here, but I find this so weird:

loess_fit <- loess(hp ~ mpg, mtcars)
ols_fit <- lm(hp ~ mpg, mtcars)

predict(loess_fit, se = TRUE)
#> $fit
#>           Mazda RX4       Mazda RX4 Wag          Datsun 710 
#>           117.13355           117.13355            94.86926 
#>      Hornet 4 Drive   Hornet Sportabout             Valiant 
#>           111.20545           146.90775           156.01765 
#>          Duster 360           Merc 240D            Merc 230 
#>           221.27945            82.55012            94.86926 
#>            Merc 280           Merc 280C          Merc 450SE 
#>           139.60797           161.06873           191.50926 
#>          Merc 450SL         Merc 450SLC  Cadillac Fleetwood 
#>           171.04944           212.02938           217.53897 
#> Lincoln Continental   Chrysler Imperial            Fiat 128 
#>           217.53897           217.73432            70.34795 
#>         Honda Civic      Toyota Corolla       Toyota Corona 
#>            68.06870            74.30992           109.96664 
#>    Dodge Challenger         AMC Javelin          Camaro Z28 
#>           208.09036           212.02938           227.37086 
#>    Pontiac Firebird           Fiat X1-9       Porsche 914-2 
#>           139.60797            70.92995            75.05325 
#>        Lotus Europa      Ford Pantera L        Ferrari Dino 
#>            68.06870           203.50778           132.83885 
#>       Maserati Bora          Volvo 142E 
#>           214.51379           111.20545 
#> 
#> $se.fit
#>           Mazda RX4       Mazda RX4 Wag          Datsun 710 
#>            12.64145            12.64145            14.01042 
#>      Hornet 4 Drive   Hornet Sportabout             Valiant 
#>            12.69869            13.27396            13.19164 
#>          Duster 360           Merc 240D            Merc 230 
#>            12.17056            14.52914            14.01042 
#>            Merc 280           Merc 280C          Merc 450SE 
#>            13.28114            13.13412            11.23666 
#>          Merc 450SL         Merc 450SLC  Cadillac Fleetwood 
#>            12.44295            11.83314            26.52021 
#> Lincoln Continental   Chrysler Imperial            Fiat 128 
#>            26.52021            12.03407            20.39697 
#>         Honda Civic      Toyota Corolla       Toyota Corona 
#>            15.66029            29.16158            12.74385 
#>    Dodge Challenger         AMC Javelin          Camaro Z28 
#>            11.71555            11.83314            12.96669 
#>    Pontiac Firebird           Fiat X1-9       Porsche 914-2 
#>            13.28114            15.92809            15.54420 
#>        Lotus Europa      Ford Pantera L        Ferrari Dino 
#>            15.66029            11.53693            12.55625 
#>       Maserati Bora          Volvo 142E 
#>            11.92343            12.69869 
#> 
#> $residual.scale
#> [1] 38.82655
#> 
#> $df
#> [1] 26.40844
predict(ols_fit, se.fit = TRUE)
#> $fit
#>           Mazda RX4       Mazda RX4 Wag          Datsun 710 
#>           138.65796           138.65796           122.76445 
#>      Hornet 4 Drive   Hornet Sportabout             Valiant 
#>           135.12607           158.96634           164.26418 
#>          Duster 360           Merc 240D            Merc 230 
#>           197.81716           108.63688           122.76445 
#>            Merc 280           Merc 280C          Merc 450SE 
#>           154.55148           166.91310           179.27473 
#>          Merc 450SL         Merc 450SLC  Cadillac Fleetwood 
#>           171.32797           189.87040           232.25311 
#> Lincoln Continental   Chrysler Imperial            Fiat 128 
#>           232.25311           194.28527            37.99903 
#>         Honda Civic      Toyota Corolla       Toyota Corona 
#>            55.65849            24.75443           134.24310 
#>    Dodge Challenger         AMC Javelin          Camaro Z28 
#>           187.22148           189.87040           206.64689 
#>    Pontiac Firebird           Fiat X1-9       Porsche 914-2 
#>           154.55148            83.03066            94.50931 
#>        Lotus Europa      Ford Pantera L        Ferrari Dino 
#>            55.65849           184.57256           150.13661 
#>       Maserati Bora          Volvo 142E 
#>           191.63635           135.12607 
#> 
#> $se.fit
#>  [1]  7.859249  7.859249  8.540431  7.955493  7.979104  8.194232 10.856161
#>  [8]  9.602008  8.540431  7.855566  8.327554  9.149276  8.585183 10.068247
#> [15] 14.879628 14.879628 10.496944 17.894399 15.576477 19.682493  7.984744
#> [22]  9.823006 10.068247 11.808185  7.855566 12.226508 10.965355 15.576477
#> [29]  9.587597  7.785322 10.236854  7.955493
#> 
#> $df
#> [1] 30
#> 
#> $residual.scale
#> [1] 43.94526

Created on 2018-08-03 by the reprex package (v0.2.0).

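i.e., you get lists rather than data frames, with inconsistent argument names (se vs. se.fit here) and a slightly different list layout ($residual.scale and $df come in opposite orders). This is easy to get into a nice format; I just find the difference between predict(ols_fit), which returns a named numeric vector, and predict(ols_fit, se.fit = TRUE), which returns a list, rather striking.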

Do most modelling packages follow this convention??
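The wrangling is short in base R. A hypothetical predict_with_se() helper, assuming loess and lm/glm are the only cases to handle, hides the se/se.fit difference and always returns a data frame with the same two columns.

```r
# Hypothetical helper: one call, one output shape (.pred, .std_error)
predict_with_se <- function(object, new_data) {
  raw <- if (inherits(object, "loess")) {
    predict(object, newdata = new_data, se = TRUE)      # loess: `se`
  } else {
    predict(object, newdata = new_data, se.fit = TRUE)  # lm/glm: `se.fit`
  }
  # as.vector() drops the row names loess attaches to fit/se.fit
  data.frame(.pred = as.vector(raw$fit), .std_error = as.vector(raw$se.fit))
}

loess_fit <- loess(hp ~ mpg, mtcars)
ols_fit <- lm(hp ~ mpg, mtcars)

class(predict(ols_fit))                 # a named numeric vector
class(predict(ols_fit, se.fit = TRUE))  # a list

a <- predict_with_se(loess_fit, mtcars)
b <- predict_with_se(ols_fit, mtcars)
head(a, 3)
head(b, 3)
```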

@topepo
Member

topepo commented Aug 3, 2018

It's really bad within glmnet too, where the type of object returned by predict() varies wildly. I was going to write a function today just to handle that.

@topepo
Member

topepo commented Sep 18, 2018

This is now implemented in the latest commit for glm, stan, and ranger models:

> library(parsnip)
> library(tidymodels)
── Attaching packages ────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.1.9000 ──
✔ ggplot2   3.0.0          ✔ recipes   0.1.3
✔ tibble    1.4.2          ✔ broom     0.5.0
✔ purrr     0.2.5          ✔ yardstick 0.0.1
✔ dplyr     0.7.99.9000    ✔ infer     0.3.1
✔ rsample   0.0.2          ✔ dials     0.0.1.9000
── Conflicts ────────────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard() masks purrr::discard()
✖ rsample::fill()   masks tidyr::fill()
✖ dplyr::filter()   masks stats::filter()
✖ dplyr::lag()      masks stats::lag()
✖ recipes::step()   masks stats::step()
> 
> set.seed(365)
> att_split <- initial_split(attrition)
> attrition_tr <- training(att_split)
> attrition_te <- testing(att_split)
> 
> attr_mod <-
+   rand_forest(
+     mode = "classification",
+     trees = 2000,
+     others = list(keep.inbag = TRUE, probability = TRUE)
+   ) %>%
+   fit(Attrition ~ Age + Department + YearsInCurrentRole,
+       data = attrition_tr,
+       engine = "ranger")
> 
> 
> predict(attr_mod, new_data = attrition_te, type = "prob") %>%
+   bind_cols(
+     predict(
+       attr_mod,
+       new_data = attrition_te,
+       type = "conf_int",
+       std_error = TRUE
+     )
+   )
# A tibble: 1,470 x 8
   .pred_No .pred_Yes .pred_lower_No .pred_upper_No .pred_lower_Yes .pred_upper_Yes .std_error_No .std_error_Yes
      <dbl>     <dbl>          <dbl>          <dbl>           <dbl>           <dbl>         <dbl>          <dbl>
 1    0.814    0.186           0.746          0.882          0.118            0.254        0.0346         0.0347
 2    0.916    0.0845          0.861          0.970          0.0309           0.138        0.0278         0.0273
 3    0.790    0.210           0.711          0.869          0.129            0.290        0.0403         0.0410
 4    0.809    0.191           0.683          0.935          0.0636           0.318        0.0643         0.0648
 5    0.845    0.155           0.782          0.908          0.0922           0.218        0.0322         0.0321
 6    0.885    0.115           0.830          0.941          0.0600           0.169        0.0283         0.0279
 7    0.814    0.186           0.746          0.882          0.117            0.255        0.0349         0.0352
 8    0.719    0.281           0.625          0.813          0.186            0.377        0.0479         0.0488
 9    0.938    0.0619          0.888          0.989          0.0126           0.111        0.0258         0.0251
10    0.939    0.0614          0.889          0.989          0.0129           0.110        0.0255         0.0248
# ... with 1,460 more rows
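The printed intervals appear consistent with a normal-approximation (Wald) interval per class, p ± z·se with z = qnorm(0.975) ≈ 1.96: row 1 gives 0.814 − 1.96 × 0.0346 ≈ 0.746. That is an inference from the rounded output above, not a description of parsnip's code. A hypothetical helper, prob_conf_int(), reproducing the .pred_lower_<class>/.pred_upper_<class> naming from the printed rows 1 and 4:

```r
# Hypothetical reconstruction: Wald-type interval p +/- z * se for each class.
# Whether parsnip computes it exactly this way is an assumption.
prob_conf_int <- function(probs, std_errors, level = 0.95) {
  z <- qnorm(1 - (1 - level) / 2)
  out <- list()
  for (cl in colnames(probs)) {
    out[[paste0(".pred_lower_", cl)]] <- probs[[cl]] - z * std_errors[[cl]]
    out[[paste0(".pred_upper_", cl)]] <- probs[[cl]] + z * std_errors[[cl]]
  }
  as.data.frame(out, check.names = FALSE)
}

# Rounded values copied from rows 1 and 4 of the tibble above
probs <- data.frame(No = c(0.814, 0.809), Yes = c(0.186, 0.191))
ses <- data.frame(No = c(0.0346, 0.0643), Yes = c(0.0347, 0.0648))
ci <- prob_conf_int(probs, ses)
round(ci, 3)
```

Small mismatches in the third decimal come from the rounding of the printed inputs. Note that intervals built this way are not clipped to [0, 1].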

@github-actions

github-actions bot commented Mar 7, 2021

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Mar 7, 2021