From cbb487b0e6ba5e19669b1b97cca4bcdce74bbe05 Mon Sep 17 00:00:00 2001
From: certara-mtalley <150705449+certara-mtalley@users.noreply.github.com>
Date: Wed, 6 Dec 2023 16:42:22 -0800
Subject: [PATCH 1/5] Generate shap plots
Create a standalone function to make SHAP plots. Adjusted mlcovsearch output to facilitate this. Updated README, as well as documentation.
---
NAMESPACE | 4 ++
R/cov_search.R | 70 +++++++++++++-----------------
R/generate_shap_summary_plot.R | 71 +++++++++++++++++++++++++++++++
README.md | 15 +++++++
man/generate_shap_summary_plot.Rd | 70 ++++++++++++++++++++++++++++++
5 files changed, 189 insertions(+), 41 deletions(-)
create mode 100644 R/generate_shap_summary_plot.R
create mode 100644 man/generate_shap_summary_plot.Rd
diff --git a/NAMESPACE b/NAMESPACE
index d22810c..4198ea8 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -3,3 +3,7 @@
S3method(print,mlcov_data)
export(MLCovSearch)
export(generate_residualsplots)
+export(generate_shap_summary_plot)
+import(SHAPforxgboost)
+import(ggplot2)
+import(xgboost)
diff --git a/R/cov_search.R b/R/cov_search.R
index 31a6193..3123449 100644
--- a/R/cov_search.R
+++ b/R/cov_search.R
@@ -199,60 +199,48 @@ MLCovSearch <- function(tab, list_pop_param, cov_continuous, cov_factors, seed =
}
}
- # Initialize an empty list to store the SHAP summary plots
- shap_plots <- list()
-
- # Create a function to generate SHAP summary plots
- generate_shap_summary_plot <- function(xgb_model, X_train, param_name) {
- shap_values <- SHAPforxgboost::shap.values(xgb_model = xgb_model, X_train = X_train)
- shap_long <- SHAPforxgboost::shap.prep(xgb_model = xgb_model, X_train = X_train)
- p <- SHAPforxgboost::shap.plot.summary(shap_long)
- p <- p + ggplot2::ggtitle(param_name)
-
- return(p)
- }
+ # Initialize an empty list to store the SHAP summary data and seed information
+ shap_data <- list()
+ shap_seed <- list()
# Interpretation of Selected covariates Beeswarm Plots
for (i in list_pop_param) {
y_xgb <- log(dat_XGB[, i])
-
- if (is.na(result_ML[i, 1]) == FALSE){
- list_cov <- strsplit(gsub(" ", "", result_ML[i, 1]), ",")
- x.selected_final <- as.matrix(dat_XGB %>% dplyr::select(dplyr::all_of(list_cov[[1]])))
-
- if (length(list_cov[[1]]) != 0 ) {
- xgb.mod_final <- xgboost::xgboost(
- data = x.selected_final,
- label = y_xgb,
- nrounds = 200,
- objective = "reg:squarederror",
- verbose = 0
- )
-
- # Generate SHAP summary plot for the current parameter
- shap_plot <- generate_shap_summary_plot(
- xgb_model = xgb.mod_final,
- X_train = x.selected_final,
- param_name = i
- )
- shap_plots[[i]] <- shap_plot
- }
+
+ if (is.na(result_ML[i, 1]) == FALSE) {
+ list_cov <- strsplit(gsub(" ", "", result_ML[i, 1]), ",")
+ x.selected_final <-
+ as.matrix(dat_XGB %>% dplyr::select(dplyr::all_of(list_cov[[1]])))
+
+ if (length(list_cov[[1]]) != 0) {
+ xgb.mod_final <- xgboost::xgboost(
+ data = x.selected_final,
+ label = y_xgb,
+ nrounds = 200,
+ objective = "reg:squarederror",
+ verbose = 0
+ )
+
+ # Generate SHAP summary plot for the current parameter
+ shap_values <- SHAPforxgboost::shap.values(xgb_model = xgb.mod_final, X_train = x.selected_final)
+ shap_long <- SHAPforxgboost::shap.prep(xgb_model = xgb.mod_final, X_train = x.selected_final)
+
+ # Store shap data and seed
+ shap_data[[i]] <- list(shap_values = shap_values, shap_long = shap_long)
+ shap_seed[[i]] <- .Random.seed
+
+ }
}
}
- combined_plots <- gridExtra::marrangeGrob(grobs = shap_plots,nrow = length(shap_plots),ncol = 1)
-
-
-
-
# Return the result_ML table and the SHAP plots for each parameter
return(
list(
result_ML = result_ML,
result_5folds = result_5folds,
- shap_plots = shap_plots,
list_pop_param = list_pop_param,
- dat_XGB = dat_XGB
+ shap_data = shap_data,
+ shap_seed = shap_seed
) %>% structure(class = "mlcov_data")
)
}
diff --git a/R/generate_shap_summary_plot.R b/R/generate_shap_summary_plot.R
new file mode 100644
index 0000000..f86740f
--- /dev/null
+++ b/R/generate_shap_summary_plot.R
@@ -0,0 +1,71 @@
+#' Generate SHAP Summary Plots
+#'
+#' This function generates SHAP summary plots for the XGBoost model.
+#' @inheritParams SHAPforxgboost::shap.plot.summary
+#' @param data A list containing required data frames and results.
+#' @param title A character string to customize the title
+#' @param title.position A numeric value from 0.0-1.0 to adjust the alignment of the title
+#' @param ylab A character string to customize the y-axis
+#' @param xlab A character string to customize the x-axis
+#' @return A list of ggplot objects, each representing a SHAP summary plot for a different parameter.
+#'
+#' @examples
+#' # Assuming 'data' is a list with necessary components
+#' \dontrun{
+#' plots <- generate_shap_summary_plot(data, title = "Custom Title", y.inter = 0.25, ...)
+#' }
+#'
+#' @import xgboost
+#' @import ggplot2
+#' @import SHAPforxgboost
+#'
+#' @export
+#'
+generate_shap_summary_plot <- function(data,
+ x_bound = NULL,
+ dilute = FALSE,
+ scientific = FALSE,
+ my_format = NULL,
+ min_color_bound = "#FFCC33",
+ max_color_bound = "#6600CC",
+ kind = c("sina", "bar"),
+ title = NULL,
+ title.position = 0,
+ ylab = NULL,
+ xlab = NULL)
+{
+ # Initialize an empty list to store the SHAP summary plots
+ shap_plots <- list()
+
+ for (i in data$list_pop_param) {
+ if (!is.null(data$shap_data[[i]])) {
+ # Set to the seed used when making the data
+ set.seed(data$shap_seed[[i]])
+
+ shap_values <- data$shap_data[[i]]$shap_values
+ shap_long <- data$shap_data[[i]]$shap_long
+
+ # Generate summary plot
+ p <-
+ SHAPforxgboost::shap.plot.summary(
+ shap_long,
+ x_bound = x_bound,
+ dilute = dilute,
+ scientific = scientific,
+ my_format = my_format,
+ min_color_bound = min_color_bound,
+ max_color_bound = max_color_bound,
+ kind = kind
+ )
+ p <- p +
+ ggplot2::ggtitle(ifelse(is.null(title), i, title)) +
+ ggplot2::labs(y = ifelse(is.null(ylab), "SHAP value (impact on model output)", ylab),
+ x = ifelse(is.null(xlab), "", xlab)) +
+ ggplot2::theme(plot.title = ggplot2::element_text(hjust = title.position))
+ shap_plots[[i]] <- p
+ }
+
+ }
+
+ return(shap_plots)
+}
diff --git a/README.md b/README.md
index 719c3e7..72c985e 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,21 @@ print(result)
```
+Generate SHAP plots:
+
+```
+generate_shap_summary_plot(
+ result,
+ x_bound = NULL,
+ dilute = FALSE,
+ scientific = FALSE,
+ my_format = NULL,
+ title = NULL,
+ title.position = 0.5,
+ ylab = NULL,
+ xlab = NULL)
+```
+
Generate residual plots:
Cl
diff --git a/man/generate_shap_summary_plot.Rd b/man/generate_shap_summary_plot.Rd
new file mode 100644
index 0000000..045eb6e
--- /dev/null
+++ b/man/generate_shap_summary_plot.Rd
@@ -0,0 +1,70 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/generate_shap_summary_plot.R
+\name{generate_shap_summary_plot}
+\alias{generate_shap_summary_plot}
+\title{Generate SHAP Summary Plots}
+\usage{
+generate_shap_summary_plot(
+ data,
+ x_bound = NULL,
+ dilute = FALSE,
+ scientific = FALSE,
+ my_format = NULL,
+ min_color_bound = "#FFCC33",
+ max_color_bound = "#6600CC",
+ kind = c("sina", "bar"),
+ title = NULL,
+ title.position = 0,
+ ylab = NULL,
+ xlab = NULL
+)
+}
+\arguments{
+\item{data}{A list containing required data frames and results.}
+
+\item{x_bound}{use to set horizontal axis limit in the plot}
+
+\item{dilute}{being numeric or logical (TRUE/FALSE), it aims to help make the test
+plot for large amount of data faster. If dilute = 5 will plot 1/5 of the
+data. If dilute = TRUE or a number, will plot at most half points per
+feature, so the plotting won't be too slow. If you put dilute too high, at
+least 10 points per feature would be kept. If the dataset is too small
+after dilution, will just plot all the data}
+
+\item{scientific}{show the mean|SHAP| in scientific format. If TRUE, label
+format is 0.0E-0, default to FALSE, and the format will be 0.000}
+
+\item{my_format}{supply your own number format if you really want}
+
+\item{min_color_bound}{min color hex code for colormap. Color gradient is
+scaled between min_color_bound and max_color_bound. Default is "#FFCC33".}
+
+\item{max_color_bound}{max color hex code for colormap. Color gradient is
+scaled between min_color_bound and max_color_bound. Default is "#6600CC".}
+
+\item{kind}{By default, a "sina" plot is shown. As an alternative,
+set \code{kind = "bar"} to visualize mean absolute SHAP values as a
+barplot. Its color is controlled by \code{max_color_bound}. Other
+arguments are ignored for this kind of plot.}
+
+\item{title}{A character string to customize the title}
+
+\item{title.position}{A numeric value from 0.0-1.0 to adjust the alignment of the title}
+
+\item{ylab}{A character string to customize the y-axis}
+
+\item{xlab}{A character string to customize the x-axis}
+}
+\value{
+A list of ggplot objects, each representing a SHAP summary plot for a different parameter.
+}
+\description{
+This function generates SHAP summary plots for the XGBoost model.
+}
+\examples{
+# Assuming 'data' is a list with necessary components
+\dontrun{
+plots <- generate_shap_summary_plot(data, title = "Custom Title", y.inter = 0.25, ...)
+}
+
+}
From 6770b63eef55fcfdfbba6ae191f484db4827f35a Mon Sep 17 00:00:00 2001
From: certara-mtalley <150705449+certara-mtalley@users.noreply.github.com>
Date: Wed, 6 Dec 2023 16:43:14 -0800
Subject: [PATCH 2/5] Remove mlcovsearch shap plots
This part of the code was placed into a new function, so tests were updated to reflect this.
---
.../_snaps/mlcovsearch/shap-plots.svg | 431 ------------------
tests/testthat/test-mlcovsearch.R | 6 -
2 files changed, 437 deletions(-)
delete mode 100644 tests/testthat/_snaps/mlcovsearch/shap-plots.svg
diff --git a/tests/testthat/_snaps/mlcovsearch/shap-plots.svg b/tests/testthat/_snaps/mlcovsearch/shap-plots.svg
deleted file mode 100644
index 7b7c2b9..0000000
--- a/tests/testthat/_snaps/mlcovsearch/shap-plots.svg
+++ /dev/null
@@ -1,431 +0,0 @@
-
-
diff --git a/tests/testthat/test-mlcovsearch.R b/tests/testthat/test-mlcovsearch.R
index 94db71f..c9a7564 100644
--- a/tests/testthat/test-mlcovsearch.R
+++ b/tests/testthat/test-mlcovsearch.R
@@ -39,12 +39,6 @@ testthat::test_that("MLCovSearch result_5folds is a data.frame with 5 non-NA cov
testthat::expect_true(sum(!is.na(result$result_5folds[1,])) == 5)
})
-# This test will need removed when shap plots are removed from MLCovSearch
-testthat::test_that("MLCovSearch shap plots do not change", {
- testthat::skip_if_not(get_os() == "windows")
- vdiffr::expect_doppelganger("shap plots", result$shap_plots)
-})
-
testthat::test_that("MLCovSearch returns an object of class `mlcov_data` with 5 components", {
testthat::expect_true(inherits(result, "mlcov_data"))
From 6d55bd278beec982c58713c14427cc4de9735282 Mon Sep 17 00:00:00 2001
From: certara-mtalley <150705449+certara-mtalley@users.noreply.github.com>
Date: Wed, 6 Dec 2023 16:43:54 -0800
Subject: [PATCH 3/5] Tests for generate shap plots
Created tests for new standalone shap plot function
---
.../custom-shap-plots.svg | 431 ++++++++++++++++++
.../default-shap-plots.svg | 431 ++++++++++++++++++
.../test-generate_shap_summary_plot.R | 58 +++
3 files changed, 920 insertions(+)
create mode 100644 tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
create mode 100644 tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg
create mode 100644 tests/testthat/test-generate_shap_summary_plot.R
diff --git a/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
new file mode 100644
index 0000000..d572fee
--- /dev/null
+++ b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
@@ -0,0 +1,431 @@
+
+
diff --git a/tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg b/tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg
new file mode 100644
index 0000000..6ef9ef6
--- /dev/null
+++ b/tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg
@@ -0,0 +1,431 @@
+
+
diff --git a/tests/testthat/test-generate_shap_summary_plot.R b/tests/testthat/test-generate_shap_summary_plot.R
new file mode 100644
index 0000000..245663a
--- /dev/null
+++ b/tests/testthat/test-generate_shap_summary_plot.R
@@ -0,0 +1,58 @@
+get_os <- function(){
+ sysinf <- Sys.info()
+ if (!is.null(sysinf)){
+ os <- sysinf['sysname']
+ if (os == 'Darwin')
+ os <- "osx"
+ } else { ## mystery machine
+ os <- .Platform$OS.type
+ if (grepl("^darwin", R.version$os))
+ os <- "osx"
+ if (grepl("linux-gnu", R.version$os))
+ os <- "linux"
+ }
+ tolower(os)
+}
+
+# Read in tab2 dataset
+data <- read.table(system.file(package = "mlcov", "supplementary", "tab2"), skip = 1, header = T)
+
+# Search and select covariates. This function can take a few minutes to run
+result <- suppressWarnings(MLCovSearch(data, #NONMEM output
+ list_pop_param = c("V1","CL"),
+ cov_continuous = c("AGE","WT","HT","BMI","ALB","CRT",
+ "FER","CHOL","WBC","LYPCT","RBC",
+ "HGB","HCT","PLT"),
+ cov_factors = c("SEX","RACE","DIAB","ALQ","WACT","SMQ")))
+
+testthat::test_that("generate_shap_plots default plots do not change", {
+ testthat::skip_if_not(get_os() == "windows")
+
+ suppressWarnings(test.plots.default <-
+ generate_shap_summary_plot(result))
+
+ vdiffr::expect_doppelganger("default shap plots", test.plots.default)
+})
+
+testthat::test_that("generate_shap_plots custom plots do not change", {
+ testthat::skip_if_not(get_os() == "windows")
+
+ suppressWarnings(
+ test.plots.custom <- generate_shap_summary_plot(
+ result,
+ x_bound = NULL,
+ dilute = TRUE,
+ scientific = TRUE,
+ my_format = NULL,
+ min_color_bound = "#336699",
+ max_color_bound = "#CC0066",
+ kind = "sina",
+ title = "Test Title",
+ title.position = 0.5,
+ ylab = "Test y-axis",
+ xlab = "Test x-axis"
+ )
+ )
+
+ vdiffr::expect_doppelganger("custom shap plots", test.plots.default)
+})
\ No newline at end of file
From ee81b17eb0d75943c6dbec8d7269fdd430558070 Mon Sep 17 00:00:00 2001
From: certara-mtalley <150705449+certara-mtalley@users.noreply.github.com>
Date: Wed, 6 Dec 2023 19:52:15 -0800
Subject: [PATCH 4/5] Updated tests
Had some errors that weren't caught until by normal devtools::test
---
.../test-generate_shap_summary_plot.R | 36 +++++++++----------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/tests/testthat/test-generate_shap_summary_plot.R b/tests/testthat/test-generate_shap_summary_plot.R
index 245663a..ab4b783 100644
--- a/tests/testthat/test-generate_shap_summary_plot.R
+++ b/tests/testthat/test-generate_shap_summary_plot.R
@@ -28,8 +28,7 @@ result <- suppressWarnings(MLCovSearch(data, #NONMEM output
testthat::test_that("generate_shap_plots default plots do not change", {
testthat::skip_if_not(get_os() == "windows")
- suppressWarnings(test.plots.default <-
- generate_shap_summary_plot(result))
+ test.plots.default <- suppressWarnings(generate_shap_summary_plot(result))
vdiffr::expect_doppelganger("default shap plots", test.plots.default)
})
@@ -37,22 +36,23 @@ testthat::test_that("generate_shap_plots default plots do not change", {
testthat::test_that("generate_shap_plots custom plots do not change", {
testthat::skip_if_not(get_os() == "windows")
- suppressWarnings(
- test.plots.custom <- generate_shap_summary_plot(
- result,
- x_bound = NULL,
- dilute = TRUE,
- scientific = TRUE,
- my_format = NULL,
- min_color_bound = "#336699",
- max_color_bound = "#CC0066",
- kind = "sina",
- title = "Test Title",
- title.position = 0.5,
- ylab = "Test y-axis",
- xlab = "Test x-axis"
+ test.plots.custom <-
+ suppressWarnings(
+ generate_shap_summary_plot(
+ result,
+ x_bound = NULL,
+ dilute = TRUE,
+ scientific = TRUE,
+ my_format = NULL,
+ min_color_bound = "#336699",
+ max_color_bound = "#CC0066",
+ kind = "sina",
+ title = "Test Title",
+ title.position = 0.5,
+ ylab = "Test y-axis",
+ xlab = "Test x-axis"
+ )
)
- )
- vdiffr::expect_doppelganger("custom shap plots", test.plots.default)
+ vdiffr::expect_doppelganger("custom shap plots", test.plots.custom)
})
\ No newline at end of file
From f039298fb5f6d49a6e06870815b03b792681b7f7 Mon Sep 17 00:00:00 2001
From: certara-mtalley <150705449+certara-mtalley@users.noreply.github.com>
Date: Wed, 6 Dec 2023 20:20:38 -0800
Subject: [PATCH 5/5] svg plot update
---
.../custom-shap-plots.svg | 577 ++++++------------
1 file changed, 196 insertions(+), 381 deletions(-)
diff --git a/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
index d572fee..0853488 100644
--- a/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
+++ b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
@@ -27,398 +27,213 @@
-
-
-
-
+
+
+
+
-
-
+
-
-
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-0.238
-0.173
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+1.7e-01
+2.4e-01
WT
AGE
-
-
+
-
-
--1.0
--0.5
+
+-0.5
0.0
-0.5
-1.0
-SHAP value (impact on model output)
+0.5
+Test y-axis
+Test x-axis
-
+
Low
High
Feature value
@@ -426,6 +241,6 @@
-V1
+Test Title