Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add lightgbm treeSHAP support #16

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 3 additions & 49 deletions .Rproj.user/shared/notebooks/paths
Original file line number Diff line number Diff line change
@@ -1,49 +1,3 @@
/Library/Frameworks/R.framework/Versions/3.5/Resources/library/RcppArmadillo/include/RcppArmadillo.h="7A2EC50"
/Users/b780620/.R/Makevars="2A873AA9"
/Users/b780620/Desktop/devel/fastshap/.Rbuildignore="9D5980A4"
/Users/b780620/Desktop/devel/fastshap/.gitattributes="1590E2AD"
/Users/b780620/Desktop/devel/fastshap/.gitignore="7E490EEC"
/Users/b780620/Desktop/devel/fastshap/.travis.yml="2720CBBB"
/Users/b780620/Desktop/devel/fastshap/DESCRIPTION="AD1517C7"
/Users/b780620/Desktop/devel/fastshap/NAMESPACE="BF528FAD"
/Users/b780620/Desktop/devel/fastshap/NEWS.md="7A48BA2E"
/Users/b780620/Desktop/devel/fastshap/R/RcppExports.R="7EE1E8DB"
/Users/b780620/Desktop/devel/fastshap/R/autoplot.R="8E338D61"
/Users/b780620/Desktop/devel/fastshap/R/explain.R="EEBAA2BE"
/Users/b780620/Desktop/devel/fastshap/R/fastshap-package.R="250FF95F"
/Users/b780620/Desktop/devel/fastshap/R/force_plot.R="DAE77890"
/Users/b780620/Desktop/devel/fastshap/R/gen_friedman.R="E242B3D6"
/Users/b780620/Desktop/devel/fastshap/R/utils.R="DAEA8479"
/Users/b780620/Desktop/devel/fastshap/README.Rmd="9999330F"
/Users/b780620/Desktop/devel/fastshap/TODO.md="111E1592"
/Users/b780620/Desktop/devel/fastshap/_pkgdown.yml="7EA64706"
/Users/b780620/Desktop/devel/fastshap/codecov.yml="A39CFB69"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_copy_classes.R="87DD3522"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_fastshap.R="4B297159"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_fastshap_adjust.R="EF3C4AD4"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_fastshap_ames.R="694D3D80"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_fastshap_exact.R="65A83803"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_fastshap_matrix.R="65E9EB44"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_fastshap_titanic.R="DE7B49F6"
/Users/b780620/Desktop/devel/fastshap/inst/tinytest/test_force_plot.R="55C2401A"
/Users/b780620/Desktop/devel/fastshap/rjournal/greenwell.Rmd="5B4336F7"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-ames.R="138A8419"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-benchmarks.R="C70C0C1C"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-boston.R="47C5AAA8"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-cpp-benchmarks.R="40E7B57E"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-genOMat.cpp="8F2A7303"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-matrix.R="AB254C9"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-par_replicate.R="9F6C61BF"
/Users/b780620/Desktop/devel/fastshap/slowtests/fastshap-xgboost.R="E70DBEEE"
/Users/b780620/Desktop/devel/fastshap/slowtests/slowtests-boston.html="4CB2AB05"
/Users/b780620/Desktop/devel/fastshap/src/.gitignore="C52AAA53"
/Users/b780620/Desktop/devel/fastshap/src/Makevars="805EA0CF"
/Users/b780620/Desktop/devel/fastshap/src/Makevars.win="C3DD2D34"
/Users/b780620/Desktop/devel/fastshap/src/RcppExports.cpp="9D414F34"
/Users/b780620/Desktop/devel/fastshap/src/fastshap.cpp="44558692"
/Users/b780620/Desktop/devel/fastshap/tests/tinytest.R="4C0F8E7D"
/Users/b780620/Desktop/devel/fastshap/tools/logo-fastshap.R="44537B4"
/Users/b780620/Desktop/devel/training/training-datarobot-api/docs/api-leaderboard.Rmd="AD7E5B2B"
/Users/b780620/Desktop/devel/training/training-datarobot-api/docs/scripts/setup-cache.R="C3F54CAC"
/Users/b780620/Desktop/devel/vip/R/gen_friedman.R="A1AA2FC7"
/Users/b780620/Desktop/devel/vip/R/vi_permute.R="DA45DF66"
/Users/bgreenwell/Dropbox/devel/fastshap/R/gen_pkg_bib.R="E1B8FA98"
/Users/bgreenwell/Dropbox/devel/fastshap/rjournal/greenwell.Rmd="44C6DFCC"
/Users/bgreenwell/Dropbox/devel/fastshap/rjournal/greenwell.bib="6C1920D6"
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: fastshap
Type: Package
Title: Fast Approximate Shapley Values
Version: 0.0.5
Version: 0.0.6
Authors@R: person("Brandon", "Greenwell", email = "greenwell.brandon@gmail.com", role =
c("aut", "cre"), comment = c(ORCID = "0000-0002-8120-0084"))
Description: Computes fast (relative to other implementations) approximate
Expand Down Expand Up @@ -32,8 +32,9 @@ Suggests:
rstudioapi,
tinytest,
titanic,
xgboost
xgboost,
lightgbm
LinkingTo:
Rcpp,
RcppArmadillo
RoxygenNote: 7.0.2
RoxygenNote: 7.1.1
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# fastshap 0.0.6

## Enhancements

* `explain()` gained a method for LightGBM models (class `lgb.Booster`); setting `exact = TRUE` returns the exact TreeSHAP contributions computed by LightGBM itself.

# fastshap 0.0.5

## Bug fixes
Expand Down
35 changes: 30 additions & 5 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ explain_column <- function(object, X, column, pred_wrapper, newdata = NULL) {
#'
#' Compute fast (approximate) Shapley values for a set of features.
#'
#' @param object A fitted model object (e.g., a \code{\link[ranger]{ranger}} or
#' an \code{\link[xgboost]{xgboost}} object).
#' @param object A fitted model object (e.g., a \code{\link[ranger]{ranger}},
#' an \code{\link[xgboost]{xgboost}}, or a \code{\link[lightgbm]{lightgbm}} object).
#'
#' @param feature_names Character string giving the names of the predictor
#' variables (i.e., features) of interest. If \code{NULL} (default) they will be
Expand Down Expand Up @@ -125,9 +125,9 @@ explain_column <- function(object, X, column, pred_wrapper, newdata = NULL) {
#' training data (i.e., \code{X}).
#'
#' @param exact Logical indicating whether to compute exact Shapley values.
#' Currently only available for \code{\link[stats]{lm}} and
#' \code{\link[xgboost]{xgboost}} objects. Default is \code{FALSE}. Note
#' that setting \code{exact = TRUE} will return explanations for each of the
#' Currently only available for \code{\link[stats]{lm}},
#' \code{\link[xgboost]{xgboost}}, and \code{\link[lightgbm]{lightgbm}} objects.
#' Default is \code{FALSE}. Note that setting \code{exact = TRUE} will return explanations for each of the
#' \code{\link[stats]{terms}} in an \code{\link[stats]{lm}} object.
#'
#' @param ... Additional optional arguments to be passed on to
Expand Down Expand Up @@ -366,3 +366,28 @@ explain.xgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1
pred_wrapper = pred_wrapper, newdata = newdata, ...)
}
}

#' @rdname explain
#'
#' @export
explain.lgb.Booster <- function(object, feature_names = NULL, X = NULL, nsim = 1,
                                pred_wrapper, newdata = NULL, exact = FALSE,
                                ...) {
  # Method for LightGBM boosters.
  #
  # With `exact = TRUE` the feature contributions are obtained directly from
  # the model via `predict(..., predcontrib = TRUE)`; `feature_names`, `nsim`,
  # and `pred_wrapper` are not used in that case. Otherwise the call is
  # forwarded unchanged to `explain.default()`.
  #
  # Returns (exact case) a tibble with one column per column of `X`, the
  # per-row "BIAS" column stored in the "baseline" attribute, and "explain"
  # appended to the class vector.
  if (isTRUE(exact)) {  # use TreeSHAP
    if (is.null(X) && is.null(newdata)) {
      stop("Must supply `X` or `newdata` argument (but not both).",
           call. = FALSE)
    }

    # `X` takes precedence over `newdata` when both are supplied
    X <- if (is.null(X)) newdata else X

    # The contribution columns are labelled from the input's column names, so
    # fail early (and clearly) if there are none instead of hitting a cryptic
    # dimnames-length error further down
    if (is.null(colnames(X))) {
      stop("`X` (or `newdata`) must have column names when `exact = TRUE`.",
           call. = FALSE)
    }

    res <- stats::predict(object, data = X, predcontrib = TRUE, ...)

    # Expect exactly one contribution per feature plus a trailing bias term
    contrib_names <- c(colnames(X), "BIAS")
    if (is.null(dim(res)) || ncol(res) != length(contrib_names)) {
      stop("Expected ", length(contrib_names), " contribution columns (one ",
           "per feature plus a bias term), but `predict()` returned ",
           NCOL(res), ". This can happen for models that produce several ",
           "outputs per observation (e.g., multiclass objectives), which ",
           "are not handled here.", call. = FALSE)
    }

    colnames(res) <- contrib_names
    res <- tibble::as_tibble(res)

    # Move the bias term out of the data and into an attribute
    attr(res, which = "baseline") <- res[["BIAS"]]
    res[["BIAS"]] <- NULL

    class(res) <- c(class(res), "explain")
    res
  } else {
    explain.default(object, feature_names = feature_names, X = X, nsim = nsim,
                    pred_wrapper = pred_wrapper, newdata = newdata, ...)
  }
}
38 changes: 38 additions & 0 deletions R/gen_pkg_bib.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# # Grab cited packages from Rmd file
# lines <- readLines("rjournal/greenwell-boehmke.Rmd")
# z <- sapply(lines, FUN = function(x) {
# stringi::stri_extract(x, pattern = "pkg\\{[:alnum:]*\\}", regex = TRUE)
# })
# z <- unname(sort(unique(z)))
# z <- gsub("^pkg\\{", replacement = "", x = z)
# z <- gsub("\\}$", replacement = "", x = z)

# Start from a clean slate: delete any previously generated bib files
stale_bibs <- c("rjournal/greenwell.bib", "rjournal/packages.bib")
invisible(lapply(Filter(file.exists, stale_bibs), file.remove))

# Packages cited in the article; each one gets an entry in the bibliography
cited_pkgs <- c(
  "iBreakDown",
  "iml",
  "fastshap",
  "Rcpp",
  "reticulate",
  "SHAPforxgboost",
  "shapper"
)

# Install whichever cited packages are not yet available locally (packages
# that are already installed are left untouched)
already_installed <- installed.packages()[, "Package"]
not_installed <- setdiff(cited_pkgs, already_installed)
install.packages(not_installed)

# Write the package citations to the article's bib file
knitr::write_bib(
  cited_pkgs,
  file = "rjournal/greenwell.bib",
  tweak = TRUE,
  width = NULL,
  prefix = "R-"
)

# Append the hand-maintained general references to the generated bib file
file.append("rjournal/greenwell.bib", "rjournal/general.bib")
Loading