Reporting uncertainty #37
Right now, there are functions to make predictions in different contexts (e.g. classes, probabilities, etc.) and a print method will bind them all via a

Those are specific to several models. My thoughts so far are to have a It would be good to have some harmonization (maybe

I don't think so since it is used by a minority of models.
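For concreteness, here is a minimal sketch of what a harmonized interface could look like: one predict() method whose type argument selects the kind of prediction, always returning a tibble with predictable column names. The logistic_reg() spec and the toy mtcars outcome are illustrative assumptions only, and the set_engine() call reflects the newer parsnip style rather than the fit(engine = ...) style used in the transcript below.

``` r
library(parsnip)
library(dplyr)

# Toy data with a factor outcome (illustrative only).
cars <- mtcars %>% mutate(am = factor(am, labels = c("auto", "manual")))

lr_fit <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(am ~ mpg + wt, data = cars)

# One predict() entry point; `type` selects the prediction flavor and each
# call returns a tibble, so results can be bound column-wise.
predict(lr_fit, new_data = cars, type = "class")  # .pred_class
predict(lr_fit, new_data = cars, type = "prob")   # .pred_auto, .pred_manual
```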
Mostly a note to myself here, but I find this so weird:

``` r
loess_fit <- loess(hp ~ mpg, mtcars)
ols_fit <- lm(hp ~ mpg, mtcars)
predict(loess_fit, se = TRUE)
#> $fit
#> Mazda RX4 Mazda RX4 Wag Datsun 710
#> 117.13355 117.13355 94.86926
#> Hornet 4 Drive Hornet Sportabout Valiant
#> 111.20545 146.90775 156.01765
#> Duster 360 Merc 240D Merc 230
#> 221.27945 82.55012 94.86926
#> Merc 280 Merc 280C Merc 450SE
#> 139.60797 161.06873 191.50926
#> Merc 450SL Merc 450SLC Cadillac Fleetwood
#> 171.04944 212.02938 217.53897
#> Lincoln Continental Chrysler Imperial Fiat 128
#> 217.53897 217.73432 70.34795
#> Honda Civic Toyota Corolla Toyota Corona
#> 68.06870 74.30992 109.96664
#> Dodge Challenger AMC Javelin Camaro Z28
#> 208.09036 212.02938 227.37086
#> Pontiac Firebird Fiat X1-9 Porsche 914-2
#> 139.60797 70.92995 75.05325
#> Lotus Europa Ford Pantera L Ferrari Dino
#> 68.06870 203.50778 132.83885
#> Maserati Bora Volvo 142E
#> 214.51379 111.20545
#>
#> $se.fit
#> Mazda RX4 Mazda RX4 Wag Datsun 710
#> 12.64145 12.64145 14.01042
#> Hornet 4 Drive Hornet Sportabout Valiant
#> 12.69869 13.27396 13.19164
#> Duster 360 Merc 240D Merc 230
#> 12.17056 14.52914 14.01042
#> Merc 280 Merc 280C Merc 450SE
#> 13.28114 13.13412 11.23666
#> Merc 450SL Merc 450SLC Cadillac Fleetwood
#> 12.44295 11.83314 26.52021
#> Lincoln Continental Chrysler Imperial Fiat 128
#> 26.52021 12.03407 20.39697
#> Honda Civic Toyota Corolla Toyota Corona
#> 15.66029 29.16158 12.74385
#> Dodge Challenger AMC Javelin Camaro Z28
#> 11.71555 11.83314 12.96669
#> Pontiac Firebird Fiat X1-9 Porsche 914-2
#> 13.28114 15.92809 15.54420
#> Lotus Europa Ford Pantera L Ferrari Dino
#> 15.66029 11.53693 12.55625
#> Maserati Bora Volvo 142E
#> 11.92343 12.69869
#>
#> $residual.scale
#> [1] 38.82655
#>
#> $df
#> [1] 26.40844
predict(ols_fit, se.fit = TRUE)
#> $fit
#> Mazda RX4 Mazda RX4 Wag Datsun 710
#> 138.65796 138.65796 122.76445
#> Hornet 4 Drive Hornet Sportabout Valiant
#> 135.12607 158.96634 164.26418
#> Duster 360 Merc 240D Merc 230
#> 197.81716 108.63688 122.76445
#> Merc 280 Merc 280C Merc 450SE
#> 154.55148 166.91310 179.27473
#> Merc 450SL Merc 450SLC Cadillac Fleetwood
#> 171.32797 189.87040 232.25311
#> Lincoln Continental Chrysler Imperial Fiat 128
#> 232.25311 194.28527 37.99903
#> Honda Civic Toyota Corolla Toyota Corona
#> 55.65849 24.75443 134.24310
#> Dodge Challenger AMC Javelin Camaro Z28
#> 187.22148 189.87040 206.64689
#> Pontiac Firebird Fiat X1-9 Porsche 914-2
#> 154.55148 83.03066 94.50931
#> Lotus Europa Ford Pantera L Ferrari Dino
#> 55.65849 184.57256 150.13661
#> Maserati Bora Volvo 142E
#> 191.63635 135.12607
#>
#> $se.fit
#> [1] 7.859249 7.859249 8.540431 7.955493 7.979104 8.194232 10.856161
#> [8] 9.602008 8.540431 7.855566 8.327554 9.149276 8.585183 10.068247
#> [15] 14.879628 14.879628 10.496944 17.894399 15.576477 19.682493 7.984744
#> [22] 9.823006 10.068247 11.808185 7.855566 12.226508 10.965355 15.576477
#> [29] 9.587597 7.785322 10.236854 7.955493
#>
#> $df
#> [1] 30
#>
#> $residual.scale
#> [1] 43.94526
```

Created on 2018-08-03 by the reprex package.

i.e. you get lists rather than data frames, with inconsistent arguments and column naming (e.g. se vs. se.fit). Do most modelling packages follow this convention?
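As a point of comparison, the two results can be coerced into a common shape by hand; a minimal sketch, assuming the .pred / .std_error column names used elsewhere in this thread:

``` r
library(tibble)

loess_fit <- loess(hp ~ mpg, mtcars)
ols_fit   <- lm(hp ~ mpg, mtcars)

# The argument name differs (`se` for predict.loess(), `se.fit` for
# predict.lm()), but both return a list with `fit` and `se.fit` elements.
loess_pred <- predict(loess_fit, se = TRUE)
ols_pred   <- predict(ols_fit, se.fit = TRUE)

tidy_pred <- function(p) {
  tibble(.pred = as.vector(p$fit), .std_error = as.vector(p$se.fit))
}

tidy_pred(loess_pred)
tidy_pred(ols_pred)
```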
It's really bad within
This is now implemented in the latest commit:

``` r
> library(parsnip)
> library(tidymodels)
── Attaching packages ────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels 0.0.1.9000 ──
✔ ggplot2 3.0.0 ✔ recipes 0.1.3
✔ tibble 1.4.2 ✔ broom 0.5.0
✔ purrr 0.2.5 ✔ yardstick 0.0.1
✔ dplyr 0.7.99.9000 ✔ infer 0.3.1
✔ rsample 0.0.2 ✔ dials 0.0.1.9000
── Conflicts ────────────────────────────────────────────────────────────────────────────────────────────────────────── tidymodels_conflicts() ──
✖ scales::discard() masks purrr::discard()
✖ rsample::fill() masks tidyr::fill()
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
✖ recipes::step() masks stats::step()
>
> set.seed(365)
> att_split <- initial_split(attrition)
> attrition_tr <- training(attrition)
> attrition_te <- testing(attrition)
>
> attr_mod <-
+ rand_forest(
+ mode = "classification",
+ trees = 2000,
+ others = list(keep.inbag = TRUE, probability = TRUE)
+ ) %>%
+ fit(Attrition ~ Age + Department + YearsInCurrentRole,
+ data = attrition_tr,
+ engine = "ranger")
>
>
> predict(attr_mod, new_data = attrition_te, type = "prob") %>%
+ bind_cols(
+ predict(
+ attr_mod,
+ new_data = attrition_te,
+ type = "conf_int",
+ std_error = TRUE
+ )
+ )
# A tibble: 1,470 x 8
.pred_No .pred_Yes .pred_lower_No .pred_upper_No .pred_lower_Yes .pred_upper_Yes .std_error_No .std_error_Yes
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 0.814 0.186 0.746 0.882 0.118 0.254 0.0346 0.0347
2 0.916 0.0845 0.861 0.970 0.0309 0.138 0.0278 0.0273
3 0.790 0.210 0.711 0.869 0.129 0.290 0.0403 0.0410
4 0.809 0.191 0.683 0.935 0.0636 0.318 0.0643 0.0648
5 0.845 0.155 0.782 0.908 0.0922 0.218 0.0322 0.0321
6 0.885 0.115 0.830 0.941 0.0600 0.169 0.0283 0.0279
7 0.814 0.186 0.746 0.882 0.117 0.255 0.0349 0.0352
8 0.719 0.281 0.625 0.813 0.186 0.377 0.0479 0.0488
9 0.938 0.0619 0.888 0.989 0.0126 0.111 0.0258 0.0251
10 0.939 0.0614 0.889 0.989 0.0129 0.110 0.0255 0.0248
# ... with 1,460 more rows
```
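For completeness, the hard class predictions can be bound onto the same tibble and passed to yardstick; a sketch that reuses attr_mod and attrition_te from the transcript above and assumes the .pred_class naming convention and current yardstick::accuracy() / roc_auc() signatures:

``` r
# Hard class predictions come back as a single .pred_class column, so they
# bind cleanly onto the probability and interval columns shown above.
results <- attrition_te %>%
  dplyr::select(Attrition) %>%
  dplyr::bind_cols(
    predict(attr_mod, new_data = attrition_te, type = "class"),
    predict(attr_mod, new_data = attrition_te, type = "prob")
  )

yardstick::accuracy(results, truth = Attrition, estimate = .pred_class)
yardstick::roc_auc(results, truth = Attrition, .pred_No)
```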
I'm reworking lots of broom::augment() methods at the moment and am discovering that packages do some crazy stuff to report uncertainty. Defining some standards for reporting uncertainty early on seems like a good idea.

For classification problems, reporting the class probabilities makes sense, but this can become problematic for outcomes with high cardinality. Nobody wants 1000 columns of class probabilities. One option is to report only the most likely class along with its probability, or the top k = 5 or so classes by default.

For regression problems I think there's more nuance. Open questions remain, e.g. whether standard errors should be reported in a column called se_fit or similar.
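To illustrate the top-k idea concretely, here is a small sketch (using current dplyr/tidyr verbs; the long format and the .pred_class / .pred_prob names are assumptions, not a proposed API):

``` r
library(dplyr)
library(tidyr)
library(tibble)

# Fake wide-format class probabilities: 5 observations, 1000 classes.
set.seed(1)
probs <- matrix(runif(5 * 1000), nrow = 5)
probs <- probs / rowSums(probs)
colnames(probs) <- paste0("class_", seq_len(ncol(probs)))

wide <- as_tibble(probs) %>% mutate(.row = row_number())

# Keep only the k most likely classes per observation, in long format,
# instead of carrying 1000 probability columns around.
k <- 5
top_k <- wide %>%
  pivot_longer(-.row, names_to = ".pred_class", values_to = ".pred_prob") %>%
  group_by(.row) %>%
  slice_max(.pred_prob, n = k) %>%
  ungroup() %>%
  arrange(.row, desc(.pred_prob))

top_k
```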