partial() not working correctly for H2O GLM #127

Open
RoelVerbelen opened this issue Nov 3, 2022 · 4 comments

Comments

@RoelVerbelen

Oddly, I'm not getting sensible results for GLMs using H2O. The effects of continuous predictors incorrectly look quadratic.

Here's a simple reprex for a Poisson GLM:

library(tidyverse)
library(finPlot)
library(h2o)

h2o.init()
h2o.no_progress()

df <- h2o.importFile("https://h2o-public-test-data.s3.amazonaws.com/smalldata/prostate/prostate.csv")

predictors <- c("AGE", "RACE", "VOL", "GLEASON")
response <- "CAPSULE"

prostate_glm <- h2o.glm(family = "poisson",
                        link = "log",
                        x = predictors,
                        y = response,
                        training_frame = df,
                        lambda = 0,
                        compute_p_values = TRUE)

# Correct PD using H2O
h2o.partialPlot(prostate_glm, df, "AGE")

# Incorrect PD using pdp
pred.grid <- data.frame(AGE = c(43, 44.8947368421053, 46.7894736842105, 48.6842105263158, 
                                50.5789473684211, 52.4736842105263, 54.3684210526316, 
                                56.2631578947368, 58.1578947368421, 60.0526315789474, 61.9473684210526, 
                                63.8421052631579, 65.7368421052632, 67.6315789473684, 69.5263157894737, 
                                71.4210526315789, 73.3157894736842, 75.2105263157895, 77.1052631578947, 79))

pred.fun <- function(object, newdata) {
  mean(as.numeric(as.vector(h2o.predict(object, as.h2o(newdata)))))
}

pd <- pdp::partial(prostate_glm, 
                   pred.var = "AGE",
                   pred.grid = pred.grid,
                   pred.fun = pred.fun, 
                   train = as.data.frame(df))

autoplot(pd)

Created on 2022-11-03 by the reprex package (v2.0.1)

I'm getting similar issues with Gaussian and binomial GLMs. Using version 0.8.1 from CRAN.

@RoelVerbelen
Author

Apologies, I found my mistake. I didn't account for the fact that, compared to GBMs, you get standard errors with H2O GLM predictions. Selecting only the first predict column solves it:

pred.fun <- function(object, newdata) {
  mean(as.numeric(as.vector(h2o.predict(object, as.h2o(newdata))[[1]])))
}
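
The effect of averaging over every prediction column can be illustrated without an H2O cluster. The following is only a base-R analogy, not H2O's actual output: it assumes a prediction table with a `predict` column and a `StdErr` column (both column names are illustrative) and uses `predict.glm(..., se.fit = TRUE)` on the built-in `mtcars` data to produce them.

```r
# Base-R analogy (not H2O): predictions that come with standard errors
fit <- glm(carb ~ wt + hp, family = poisson(link = "log"), data = mtcars)
pred <- predict(fit, newdata = mtcars, type = "response", se.fit = TRUE)

# Illustrative column names; an assumption, not taken from H2O
pred_df <- data.frame(predict = pred$fit, StdErr = pred$se.fit)

# Flattening every column pools the predictions with their standard errors
mean_all <- mean(unlist(pred_df))

# Averaging only the prediction column gives the intended quantity
mean_fit <- mean(pred_df[["predict"]])

print(c(all_columns = mean_all, predict_only = mean_fit))
```

Because the standard errors enter the average, `mean_all` no longer estimates the mean prediction, which is the same kind of distortion the corrected `pred.fun` avoids by keeping only the first column.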

pdp is still a lot slower than h2o.partialPlot(), though, because it makes multiple predict calls instead of one. The in.memory argument you proposed here still sounds promising.

@bgreenwell
Owner

Thanks for the note @RoelVerbelen, glad you found the issue! The bigger bottleneck here is probably the multiple calls needed to coerce the data with as.h2o, so in cases like this I can see the original in.memory argument providing an advantage. The same idea can be used to compute PD plots in Spark (i.e., with SparkR or sparklyr). In fact, the in.memory argument effectively generalized the following approach to computing PD plots in Spark with sparklyr, which was suggested in this closed issue:

# Load required packages
library(dplyr)
library(pdp)
library(sparklyr)

data(boston, package = "pdp")

sc <- spark_connect(master = 'local')
boston_sc <- copy_to(sc, boston, overwrite = TRUE)
rfo <- boston_sc %>% ml_random_forest(cmedv ~ ., type = "auto")

# Define plotting grid 
df1 <- data.frame(lstat = quantile(boston$lstat, probs = 1:19/20)) %>% 
  copy_to(sc, df = .)

# Remove plotting variable from training data
df2 <- boston %>%
  select(-lstat) %>%
  copy_to(sc, df = .)

# Perform a cross join, compute predictions, then aggregate
par_dep <- df1 %>%
  full_join(df2, by = character()) %>%  # cartesian product
  ml_predict(rfo, dataset = .) %>%
  group_by(lstat) %>%  
  summarize(yhat = mean(prediction)) %>%  # average for partial dependence
  select(lstat, yhat) %>%  # select plotting variables
  arrange(lstat) %>%  # for plotting purposes
  collect()

# Plot results
plot(par_dep, type = "l")

You can try this with your h2o example. If you see a significant improvement, then I'd consider revisiting the idea! (Although I think the original implementation used data.table for the joins and aggregations.)

@RoelVerbelen
Author

Hi @bgreenwell, thanks for the response!

Absolutely, it's the conversion from R to H2O (using as.h2o() each time) that's making it so slow. The actual scoring in H2O is really fast. That conversion is currently a bottleneck of pdp::partial() for H2O (and Spark). The benefit of doing the conversion only once is massive, even on this toy example:

# Built-in H2O function: df is already an H2OFrame, so no conversion is needed
system.time({
  h2o.partialPlot(prostate_glm, df, "AGE")
})
 user  system elapsed 
 0.07    0.00    1.13 

# pdp does as many H2O conversions as nrow(pred.grid)
system.time({
  pd <- pdp::partial(prostate_glm, 
                     pred.var = "AGE",
                     pred.grid = pred.grid,
                     pred.fun = pred.fun, 
                     train = as.data.frame(df))
})
 user  system elapsed 
11.89    2.50  100.55 

# Not relying on pdp and only converting once
system.time({
  pd_fast <- pred.grid %>% 
    full_join(
      df %>% 
        as.data.frame() %>% 
        select(-AGE),
      by = character()
    ) %>% 
    mutate(yhat = as.vector(h2o.predict(prostate_glm, as.h2o(.))[[1]])) %>% 
    group_by(AGE) %>% 
    summarise(yhat = mean(yhat)) %>% 
    select(AGE, yhat)
})
   user  system elapsed 
   0.56    0.00    4.80 

I love how general pdp is, so having this in.memory option available would be a really powerful improvement.
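
For reference, the "convert once" idea above can be sketched as a small base-R function. This is a minimal sketch, not pdp's in.memory implementation: `pd_cross_join()` and its `predict_fun` argument are hypothetical names introduced here, and the demonstration uses `lm()` on `mtcars` so that it runs without an H2O cluster.

```r
# Hypothetical helper (not part of pdp): partial dependence via one cross join,
# so the stacked data is scored (and, for H2O, converted) only once.
pd_cross_join <- function(object, train, pred.var, grid, predict_fun = predict) {
  # One-column data frame holding the grid of values for the plotting variable
  grid_df <- setNames(data.frame(unname(grid)), pred.var)
  # Training data without the plotting variable
  others <- train[, setdiff(names(train), pred.var), drop = FALSE]
  # by = NULL makes merge() return the Cartesian product
  stacked <- merge(grid_df, others, by = NULL)
  # Single prediction call over all grid-by-row combinations
  stacked$yhat <- as.numeric(predict_fun(object, stacked))
  # Average the predictions within each grid value
  aggregate(stacked["yhat"], by = stacked[pred.var], FUN = mean)
}

fit <- lm(mpg ~ wt + hp + disp, data = mtcars)
grid <- quantile(mtcars$wt, probs = 1:9 / 10)
pd <- pd_cross_join(fit, mtcars, "wt", grid)
print(pd)
```

For the H2O model in this thread, `predict_fun` could presumably be something like `function(object, newdata) as.vector(h2o.predict(object, as.h2o(newdata))[[1]])`, mirroring the one-conversion snippet above; that variant is untested here. Note that the stacked frame has `length(grid) * nrow(train)` rows, so the approach trades memory for fewer conversions.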

@bgreenwell
Owner

Sorry for the delay @RoelVerbelen, that is a convincing example! I'll reopen the issue, but I'm not sure when I'll get to it. It should be easy to resurrect the old branch and grab the code.

@bgreenwell reopened this Nov 16, 2022