Commit dafe2d7

Updated GBM ALE objects to respond to ggplot2 3.5.0 (#2).
1 parent bed304a commit dafe2d7

6 files changed, 17 insertions(+), 16 deletions(-)

download/gbm.data_model.rds

-1.68 KB
Binary file not shown.

download/gbm_ale_ixn_link.rds

5.12 MB
Binary file not shown.

download/gbm_ale_ixn_prob.rds

5.1 MB
Binary file not shown.

download/gbm_ale_link.rds

19.2 MB
Binary file not shown.

download/gbm_ale_prob.rds

19.2 MB
Binary file not shown.

vignettes/ale-ALEPlot.Rmd

+17 -16
@@ -226,8 +226,8 @@ data <-
 Although gradient boosted trees generally perform quite well, they are rather slow. Rather than having you wait for it to run, the code here downloads a pretrained GBM model. However, the code used to generate it is provided in comments so that you can see it and run it yourself if you want to. Note that the model calls is based on `data[,-c(3,4)]`, which drops the third and fourth variables (`fnlwgt` and `education`, respectively).
 
 ```{r gbm model}
-# To generate the code, uncomment the following lines.
-# But it is slow, so this vignette loads a pre-created model object.
+# # To generate the code, uncomment the following lines.
+# # But they are slow, so this vignette loads a pre-created model object.
 # set.seed(0)
 # gbm.data <- gbm(higher_income ~ ., data= data[,-c(3,4)],
 # distribution = "bernoulli", n.trees=6000, shrinkage=0.02,
@@ -274,19 +274,18 @@ We display all the plots because it is easy to do so with the `{ale}` package bu
 
 ```{r ale one-way link, fig.width=7, fig.height=20}
 # Custom predict function that returns log odds
-yhat <- function(object, newdata) {
-as.numeric(
-predict(object, newdata, n.trees = 6000,
-type="link") # return log odds
-)
+yhat <- function(object, newdata, type) {
+predict(object, newdata, type='link', n.trees = 6000) |> # return log odds
+as.numeric()
 }
 
 # Generate ALE data for all variables
 
 # # To generate the code, uncomment the following lines.
 # # But it is slow, so this vignette loads a pre-created model object.
 # gbm_ale_link <- ale(
-# data[,-c(3,4)], gbm.data,
+# # data[,-c(3,4)], gbm.data,
+# data, gbm.data,
 # pred_fun = yhat,
 # x_intervals = 500,
 # rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
@@ -307,13 +306,13 @@ Now we generate ALE data for all two-way interactions and then plot them. Again,
 # # To generate the code, uncomment the following lines.
 # # But it is slow, so this vignette loads a pre-created model object.
 # gbm_ale_ixn_link <- ale_ixn(
-# data[,-c(3,4)], gbm.data,
+# # data[,-c(3,4)], gbm.data,
+# data, gbm.data,
 # pred_fun = yhat,
 # x_intervals = 500,
 # rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
 # relative_y = 'zero', # compatibility with ALEPlot
 # model_packages = 'gbm' # required for parallel processing
-
 # )
 # saveRDS(gbm_ale_ixn_link, file.choose())
 gbm_ale_ixn_link <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_ixn_link.rds') |>
@@ -342,7 +341,7 @@ As we can see, the shapes of the plots are similar, but the y axes are more easi
 
 ```{r ale one-way prob, fig.width=7, fig.height=20}
 # Custom predict function that returns predicted probabilities
-yhat <- function(object, newdata) {
+yhat <- function(object, newdata, type) {
 as.numeric(
 predict(object, newdata, n.trees = 6000,
 type="response") # return predicted probabilities
@@ -354,11 +353,12 @@ yhat <- function(object, newdata) {
 # # To generate the code, uncomment the following lines.
 # # But it is slow, so this vignette loads a pre-created model object.
 # gbm_ale_prob <- ale(
-# data[,-c(3,4)], gbm.data,
+# # data[,-c(3,4)], gbm.data,
+# data, gbm.data,
 # pred_fun = yhat,
 # x_intervals = 500,
 # rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
-# model_packages = 'nnet' # required for parallel processing
+# model_packages = 'gbm' # required for parallel processing
 # )
 # saveRDS(gbm_ale_prob, file.choose())
 gbm_ale_prob <- url('https://github.com/Tripartio/ale/raw/main/download/gbm_ale_prob.rds') |>
@@ -371,10 +371,11 @@ gridExtra::grid.arrange(grobs = gbm_ale_prob$plots, ncol = 2)
 Finally, we again generate two-way interactions, this time based on probabilities instead of on log odds. However, probabilities might not be the best choice for indicating interactions because, as we see from the rugs in the one-way ALE plots, the GBM model heavily concentrates its probabilities in the extremes near 0 and 1. Thus, the plots' suggestions of strong interactions are likely exaggerated. In this case, the log odds ALEs shown above are probably more relevant.
 
 ```{r ale ixn prob, fig.width=7, fig.height=5}
-# # To generate the code, uncomment the following lines.
-# # But it is slow, so this vignette loads a pre-created model object.
+# To generate the code, uncomment the following lines.
+# But it is slow, so this vignette loads a pre-created model object.
 # gbm_ale_ixn_prob <- ale_ixn(
-# data[,-c(3,4)], gbm.data,
+# # data[,-c(3,4)], gbm.data,
+# data, gbm.data,
 # pred_fun = yhat,
 # x_intervals = 500,
 # rug_sample_size = 600, # technical issue: rug_sample_size must be > x_intervals + 1
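
The hunks in this commit give `yhat` a third `type` argument while still fixing the prediction scale inside the function body. The sketch below illustrates that pattern only; it is not code from this commit. It assumes base R `glm()` as a stand-in for the GBM model so that it runs without the `gbm` package or the downloaded `.rds` files, and `toy`, `fit`, `yhat_link`, and `yhat_prob` are hypothetical names. The native pipe requires R 4.1 or later.

```r
# Stand-in for the GBM model (assumption): a small logistic regression
# fitted to simulated data.
set.seed(0)
toy <- data.frame(x1 = rnorm(200), x2 = rnorm(200))
toy$y <- rbinom(200, 1, plogis(0.8 * toy$x1 - 0.5 * toy$x2))
fit <- glm(y ~ x1 + x2, data = toy, family = binomial())

# Same three-argument shape as the updated yhat(). `type` is accepted but
# never used, because each function hard-codes its own prediction scale.
yhat_link <- function(object, newdata, type) {
  predict(object, newdata, type = "link") |>  # log odds
    as.numeric()
}
yhat_prob <- function(object, newdata, type) {
  predict(object, newdata, type = "response") |>  # predicted probabilities
    as.numeric()
}

link <- yhat_link(fit, toy)
prob <- yhat_prob(fit, toy)
```

Both functions return a plain numeric vector with one value per row of `newdata`, and the two scales are related by the logistic function, so `plogis(link)` should match `prob` up to floating-point tolerance.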

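The vignette text in the last hunk notes that the GBM model concentrates its probabilities near 0 and 1 and that the log odds ALEs are probably more relevant. One way to see why the two scales can tell different stories is that the same shift in log odds produces very different shifts in probability depending on the starting point. The following is a minimal base R sketch of that nonlinearity; the values in `log_odds` and `shift` are arbitrary illustrative choices, not taken from the model.

```r
# Arbitrary baseline predictions on the log-odds scale (illustrative only).
log_odds <- c(-4, 0, 4)
shift <- 1  # the same additive effect applied at every baseline

# Change in predicted probability caused by that identical log-odds shift.
prob_change <- plogis(log_odds + shift) - plogis(log_odds)

# The change is largest at log odds 0 (probability 0.5) and much smaller
# at the extremes, where probabilities are already close to 0 or 1.
print(data.frame(log_odds, prob_change))
```

This does not by itself establish how the interaction plots are distorted; it only shows that effects of equal size in log odds are not of equal size in probability.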