[R-package] Output structure of predict() #5223

Closed
mayer79 opened this issue May 18, 2022 · 4 comments

@mayer79
Contributor

mayer79 commented May 18, 2022

How do I read the output of predict() in the R package for a multiclass model? Take a specific example with k = 2 trees per class, m = 3 classes and p = 4 features:

  • predict(...): Clear, one column per class. Just to get warm.
  • predict(..., predleaf = TRUE): The first k columns give me the tree node indices for the first class. The next k columns those for the second class etc? Or is the order different?
  • predict(..., predcontrib = TRUE): The first p + 1 columns in the output give me the SHAP values of the p features (and a BIAS) for the first class etc?

Thus, we always get all results for the first class, then those for the second class etc?

library(lightgbm)

X <- data.matrix(iris[, -5])
y <- -1L + as.integer(iris[[5]])

fit <- lgb.train(
  params = list(objective = "multiclass", num_class = 3)
  , data = lgb.Dataset(X, label = y)
  , nrounds = 2
  , verbose = -2
)

predict(fit, head(X))
#           [,1]      [,2]      [,3]
# [1,] 0.4671195 0.2663959 0.2664846
# [2,] 0.4603037 0.2698664 0.2698299
# [3,] 0.4603037 0.2698664 0.2698299
# [4,] 0.4603998 0.2697140 0.2698863
# [5,] 0.4671195 0.2663959 0.2664846
# [6,] 0.4672158 0.2662447 0.2665395

predict(fit, head(X), predleaf = TRUE)
#      [,1] [,2] [,3] [,4] [,5] [,6]
# [1,]    2    0    0    2    0    5
# [2,]    0    0    0    0    0    3
# [3,]    0    0    0    0    0    3
# [4,]    0    0    0    0    5    3
# [5,]    2    0    0    2    0    5
# [6,]    2    0    4    2    5    5

predict(fit, head(X), predcontrib = TRUE)
#       [,1]         [,2]      [,3]          [,4]      [,5] [,6] [,7]       [,8]        [,9]     [,10] [,11]
# [1,]    0  0.007262867 0.3671346 -3.332467e-05 -1.106534    0    0 -0.1354436 -0.05341171 -1.104917     0
# [2,]    0 -0.011257445 0.3580126 -3.332467e-05 -1.106534    0    0 -0.1354436 -0.05341171 -1.104917     0
# [3,]    0 -0.011257445 0.3580126 -3.332467e-05 -1.106534    0    0 -0.1354436 -0.05341171 -1.104917     0
# [4,]    0 -0.011257445 0.3580126 -3.332467e-05 -1.106534    0    0 -0.1359543 -0.05367475 -1.104917     0
# [5,]    0  0.007262867 0.3671346 -3.332467e-05 -1.106534    0    0 -0.1354436 -0.05341171 -1.104917     0
# [6,]    0  0.007262867 0.3671346 -3.332467e-05 -1.106534    0    0 -0.1359543 -0.05367475 -1.104917     0
#             [,12]       [,13]       [,14]     [,15]
# [1,] -0.002416721 -0.09863067 -0.08731765 -1.105074
# [2,] -0.002757015 -0.09875858 -0.08731765 -1.105074
# [3,] -0.002757015 -0.09875858 -0.08731765 -1.105074
# [4,] -0.002757015 -0.09875858 -0.08731765 -1.105074
# [5,] -0.002416721 -0.09863067 -0.08731765 -1.105074
# [6,] -0.002416721 -0.09863067 -0.08731765 -1.105074
@jameslamb jameslamb changed the title [R package] [Question] Output structure of predict [R-package] Output structure of predict() Jul 17, 2022
@jameslamb
Collaborator

jameslamb commented Jul 19, 2022

Hey @mayer79, sorry it took so long to get back to you! 😆 I laughed at "just to get warm".

And thanks very much for the small reproducible example. It made it much easier to understand exactly what you were asking and how to investigate it.

It took me a while to respond because I wanted to prove to myself that I had the right understanding.

I believe that it's:

  • "leaf" = ordered by tree (see examples below)
  • "contrib" = ordered by class (all contributions for class 1, followed by all for class 2, etc.)

First, the SHAP values. Since SHAP values should sum to the raw prediction, you can confirm that the row-wise sums of each contiguous block of p + 1 columns match the "raw score" predictions for the corresponding class.

num_features <- ncol(X)
num_shap_cols_per_class <- num_features + 1
preds_contrib <- predict(fit, head(X), type = "contrib")
preds_raw <- predict(fit, head(X), type = "raw")

# class 1
rowSums(preds_contrib[, 1:num_shap_cols_per_class])
# [1] -0.7321704 -0.7598126 -0.7598126 -0.7598126 -0.7321704 -0.7321704
preds_raw[, 1]
# [1] -0.7321704 -0.7598126 -0.7598126 -0.7598126 -0.7321704 -0.7321704

# class 2
rowSums(preds_contrib[, (num_shap_cols_per_class + 1):(num_shap_cols_per_class * 2)])
# [1] -1.293772 -1.293772 -1.293772 -1.294546 -1.293772 -1.294546
preds_raw[, 2]
# [1] -1.293772 -1.293772 -1.293772 -1.294546 -1.293772 -1.294546

# class 3
rowSums(preds_contrib[, (num_shap_cols_per_class * 2 + 1):(num_shap_cols_per_class * 3)])
# [1] -1.293439 -1.293907 -1.293907 -1.293907 -1.293439 -1.293439
preds_raw[, 3]
# [1] -1.293439 -1.293907 -1.293907 -1.293907 -1.293439 -1.293439
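
The same check can be written once for any number of classes. This is only a minimal sketch, assuming the class-by-class block layout described above; split_contrib_by_class() and contrib_to_raw() are hypothetical helpers, not part of lightgbm.

```r
# Hypothetical helper (not part of lightgbm): split a "contrib" prediction
# matrix into one block per class, assuming the columns are laid out as
# [class 1: features + BIAS, class 2: features + BIAS, ...].
split_contrib_by_class <- function(contrib, num_class) {
  stopifnot(is.matrix(contrib), ncol(contrib) %% num_class == 0)
  cols_per_class <- ncol(contrib) %/% num_class
  lapply(seq_len(num_class), function(k) {
    first <- (k - 1L) * cols_per_class + 1L
    contrib[, first:(first + cols_per_class - 1L), drop = FALSE]
  })
}

# Row sums of each block, bound together as one column per class.
contrib_to_raw <- function(contrib, num_class) {
  blocks <- split_contrib_by_class(contrib, num_class)
  do.call(cbind, lapply(blocks, rowSums))
}
```

With the objects from the example, all.equal(contrib_to_raw(preds_contrib, 3), preds_raw, check.attributes = FALSE) should then return TRUE if the assumed layout holds.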

The leaf multiclass predictions in the R package don't follow that same pattern. Those are ordered by tree.

library(data.table)

preds_raw <- predict(fit, head(X), type = "raw")
preds_leaf <- predict(fit, head(X), type = "leaf")

# with three classes, each boosting iteration produces 3 trees
# this is why the tree indices (integer unique IDs) skip by 3
# 
# * first class: 0, 3
# * second class: 1, 4
# * third class: 2, 5
#
# create a table mapping those to the leaf indices produced by predict(..., type = "leaf")
firstRowDT <- data.table::data.table(
    target_class = c(1, 2, 3, 1, 2, 3)
    , tree_index = c(0, 1, 2, 3, 4, 5)
    , leaf_index = preds_leaf[1, ]
)
#    target_class tree_index leaf_index
# 1:            1          0          2
# 2:            2          1          0
# 3:            3          2          0
# 4:            1          3          2
# 5:            2          4          0
# 6:            3          5          5

# next, dump the model information, which describes every node
# (including its predicted value for samples falling into it, and its tree_index and leaf_index)
modelDT <- lightgbm::lgb.model.dt.tree(fit)

# join them together
joinedDT <- merge(
    x = firstRowDT
    , y = modelDT[, .(tree_index, leaf_index, leaf_value)]
    , by = c("tree_index", "leaf_index")
    , all.x = TRUE
)
joinedDT[, .(target_class, tree_index, leaf_index, leaf_value)]
#    target_class tree_index leaf_index leaf_value
# 1:            1          0          2 -0.8986123
# 2:            2          1          0 -1.1986123
# 3:            3          2          0 -1.1986123
# 4:            1          3          2  0.1664419
# 5:            2          4          0 -0.0951597
# 6:            3          5          5 -0.0948266

# get predictions by summing leaf values from trees belonging to the same class
joinedDT[
    , sum(leaf_value)
    , by = target_class
][["V1"]]
# [1] -0.7321704 -1.2937720 -1.2934389
preds_raw[1, ]
# [1] -0.7321704 -1.2937720 -1.2934389

I know that example is a bit complicated, but the core idea is this: I grabbed the first row of the leaf predictions, did not reorder them, and then hard-coded an assertion about which class and tree index each element belongs to. If that mapping weren't correct, the code I've provided wouldn't produce exactly the same values as predict(..., type = "raw").
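
The hard-coded mapping can also be generated for other model sizes. A minimal sketch, assuming leaf columns follow tree_index order and each boosting iteration adds one tree per class; leaf_column_map() is a hypothetical helper, not part of lightgbm.

```r
# Hypothetical helper (not part of lightgbm): describe each column of a
# "leaf" prediction matrix, assuming columns are ordered by tree_index and
# every boosting iteration contributes one tree per class.
leaf_column_map <- function(num_class, num_iterations) {
  num_class <- as.integer(num_class)
  num_iterations <- as.integer(num_iterations)
  tree_index <- seq_len(num_class * num_iterations) - 1L
  data.frame(
    column = tree_index + 1L,                    # 1-based column in the matrix
    tree_index = tree_index,                     # 0-based tree ID in the model
    iteration = tree_index %/% num_class + 1L,   # boosting round of the tree
    target_class = tree_index %% num_class + 1L  # class the tree belongs to
  )
}

leaf_column_map(num_class = 3, num_iterations = 2)
```

For 3 classes and 2 iterations, the target_class column reproduces the c(1, 2, 3, 1, 2, 3) vector used in the table above.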

If you have time and interest, we'd welcome a contribution adding a note describing this to the documentation under #' @name predict.lgb.Booster. I think it would be a great addition as a standalone note (instead of intermingled in the type documentation).

If not, then just let me know and I'd be happy to add that to the docs.

@mayer79
Contributor Author

mayer79 commented Jul 19, 2022

Ingenious! Thanks a lot. I think we should indeed explain this briefly in the help page.

@mayer79 mayer79 closed this as completed Jul 19, 2022
@jameslamb
Collaborator

I'll make this addition to the docs later today.

@github-actions

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Aug 19, 2023