diff --git a/NAMESPACE b/NAMESPACE
index d22810c..4198ea8 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -3,3 +3,7 @@
S3method(print,mlcov_data)
export(MLCovSearch)
export(generate_residualsplots)
+export(generate_shap_summary_plot)
+import(SHAPforxgboost)
+import(ggplot2)
+import(xgboost)
diff --git a/R/cov_search.R b/R/cov_search.R
index 31a6193..3123449 100644
--- a/R/cov_search.R
+++ b/R/cov_search.R
@@ -199,60 +199,48 @@ MLCovSearch <- function(tab, list_pop_param, cov_continuous, cov_factors, seed =
}
}
- # Initialize an empty list to store the SHAP summary plots
- shap_plots <- list()
-
- # Create a function to generate SHAP summary plots
- generate_shap_summary_plot <- function(xgb_model, X_train, param_name) {
- shap_values <- SHAPforxgboost::shap.values(xgb_model = xgb_model, X_train = X_train)
- shap_long <- SHAPforxgboost::shap.prep(xgb_model = xgb_model, X_train = X_train)
- p <- SHAPforxgboost::shap.plot.summary(shap_long)
- p <- p + ggplot2::ggtitle(param_name)
-
- return(p)
- }
+ # Initialize an empty list to store the SHAP summary data and seed information
+ shap_data <- list()
+ shap_seed <- list()
# Interpretation of Selected covariates Beeswarm Plots
for (i in list_pop_param) {
y_xgb <- log(dat_XGB[, i])
-
- if (is.na(result_ML[i, 1]) == FALSE){
- list_cov <- strsplit(gsub(" ", "", result_ML[i, 1]), ",")
- x.selected_final <- as.matrix(dat_XGB %>% dplyr::select(dplyr::all_of(list_cov[[1]])))
-
- if (length(list_cov[[1]]) != 0 ) {
- xgb.mod_final <- xgboost::xgboost(
- data = x.selected_final,
- label = y_xgb,
- nrounds = 200,
- objective = "reg:squarederror",
- verbose = 0
- )
-
- # Generate SHAP summary plot for the current parameter
- shap_plot <- generate_shap_summary_plot(
- xgb_model = xgb.mod_final,
- X_train = x.selected_final,
- param_name = i
- )
- shap_plots[[i]] <- shap_plot
- }
+
+ if (is.na(result_ML[i, 1]) == FALSE) {
+ list_cov <- strsplit(gsub(" ", "", result_ML[i, 1]), ",")
+ x.selected_final <-
+ as.matrix(dat_XGB %>% dplyr::select(dplyr::all_of(list_cov[[1]])))
+
+ if (length(list_cov[[1]]) != 0) {
+ xgb.mod_final <- xgboost::xgboost(
+ data = x.selected_final,
+ label = y_xgb,
+ nrounds = 200,
+ objective = "reg:squarederror",
+ verbose = 0
+ )
+
+        # Compute SHAP values and the long-format SHAP data for the current parameter
+ shap_values <- SHAPforxgboost::shap.values(xgb_model = xgb.mod_final, X_train = x.selected_final)
+ shap_long <- SHAPforxgboost::shap.prep(xgb_model = xgb.mod_final, X_train = x.selected_final)
+
+ # Store shap data and seed
+ shap_data[[i]] <- list(shap_values = shap_values, shap_long = shap_long)
+ shap_seed[[i]] <- .Random.seed
+
+ }
}
}
- combined_plots <- gridExtra::marrangeGrob(grobs = shap_plots,nrow = length(shap_plots),ncol = 1)
-
-
-
-
# Return the result_ML table and the SHAP plots for each parameter
return(
list(
result_ML = result_ML,
result_5folds = result_5folds,
- shap_plots = shap_plots,
list_pop_param = list_pop_param,
- dat_XGB = dat_XGB
+ shap_data = shap_data,
+ shap_seed = shap_seed
) %>% structure(class = "mlcov_data")
)
}
diff --git a/R/generate_shap_summary_plot.R b/R/generate_shap_summary_plot.R
new file mode 100644
index 0000000..f86740f
--- /dev/null
+++ b/R/generate_shap_summary_plot.R
@@ -0,0 +1,87 @@
+#' Generate SHAP Summary Plots
+#'
+#' This function generates one SHAP summary plot per population parameter
+#' from the SHAP data stored in \code{data}.
+#'
+#' \code{set.seed()} is called for every plotted parameter, so the global
+#' random number generator state is modified as a side effect.
+#'
+#' @inheritParams SHAPforxgboost::shap.plot.summary
+#' @param data A list providing \code{list_pop_param}, \code{shap_data} and
+#'   \code{shap_seed}, such as the object returned by \code{MLCovSearch()}.
+#' @param title A character string to customize the title. When \code{NULL},
+#'   each plot is titled with the name of its parameter.
+#' @param title.position A numeric value from 0.0-1.0 to adjust the alignment of the title
+#' @param ylab A character string to customize the y-axis. When \code{NULL},
+#'   "SHAP value (impact on model output)" is used.
+#' @param xlab A character string to customize the x-axis. When \code{NULL},
+#'   the label is left empty.
+#' @return A list of ggplot objects, each representing a SHAP summary plot for
+#'   a different parameter. Parameters without SHAP data are skipped.
+#'
+#' @examples
+#' # Assuming 'data' is a list with necessary components
+#' \dontrun{
+#' plots <- generate_shap_summary_plot(data, title = "Custom Title", title.position = 0.5)
+#' }
+#'
+#' @import xgboost
+#' @import ggplot2
+#' @import SHAPforxgboost
+#'
+#' @export
+#'
+generate_shap_summary_plot <- function(data,
+                                       x_bound = NULL,
+                                       dilute = FALSE,
+                                       scientific = FALSE,
+                                       my_format = NULL,
+                                       min_color_bound = "#FFCC33",
+                                       max_color_bound = "#6600CC",
+                                       kind = c("sina", "bar"),
+                                       title = NULL,
+                                       title.position = 0,
+                                       ylab = NULL,
+                                       xlab = NULL) {
+  # Resolve the enumerated choice once so an invalid `kind` fails up front.
+  kind <- match.arg(kind)
+
+  # Axis labels are the same for every parameter, so compute them once.
+  # Plain if/else passes a user-supplied label through untouched.
+  y_label <- if (is.null(ylab)) "SHAP value (impact on model output)" else ylab
+  x_label <- if (is.null(xlab)) "" else xlab
+
+  shap_plots <- list()
+
+  for (i in data$list_pop_param) {
+    shap_entry <- data$shap_data[[i]]
+    if (is.null(shap_entry)) {
+      # No SHAP data stored for this parameter: nothing to plot.
+      next
+    }
+
+    # NOTE(review): MLCovSearch stores `.Random.seed` (a full RNG state vector)
+    # in `shap_seed`, whereas set.seed() expects a single integer seed, so this
+    # presumably does not restore that exact state. Left unchanged because
+    # altering it may change the committed plot snapshots - confirm intent.
+    set.seed(data$shap_seed[[i]])
+
+    p <- SHAPforxgboost::shap.plot.summary(
+      shap_entry$shap_long,
+      x_bound = x_bound,
+      dilute = dilute,
+      scientific = scientific,
+      my_format = my_format,
+      min_color_bound = min_color_bound,
+      max_color_bound = max_color_bound,
+      kind = kind
+    )
+
+    shap_plots[[i]] <- p +
+      ggplot2::ggtitle(if (is.null(title)) i else title) +
+      ggplot2::labs(y = y_label, x = x_label) +
+      ggplot2::theme(plot.title = ggplot2::element_text(hjust = title.position))
+  }
+
+  shap_plots
+}
diff --git a/README.md b/README.md
index 719c3e7..72c985e 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,21 @@ print(result)
```
+Generate SHAP plots:
+
+```
+generate_shap_summary_plot(
+ result,
+ x_bound = NULL,
+ dilute = FALSE,
+ scientific = FALSE,
+ my_format = NULL,
+ title = NULL,
+ title.position = 0.5,
+ ylab = NULL,
+ xlab = NULL)
+```
+
Generate residual plots:
Cl
diff --git a/man/generate_shap_summary_plot.Rd b/man/generate_shap_summary_plot.Rd
new file mode 100644
index 0000000..045eb6e
--- /dev/null
+++ b/man/generate_shap_summary_plot.Rd
@@ -0,0 +1,70 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/generate_shap_summary_plot.R
+\name{generate_shap_summary_plot}
+\alias{generate_shap_summary_plot}
+\title{Generate SHAP Summary Plots}
+\usage{
+generate_shap_summary_plot(
+ data,
+ x_bound = NULL,
+ dilute = FALSE,
+ scientific = FALSE,
+ my_format = NULL,
+ min_color_bound = "#FFCC33",
+ max_color_bound = "#6600CC",
+ kind = c("sina", "bar"),
+ title = NULL,
+ title.position = 0,
+ ylab = NULL,
+ xlab = NULL
+)
+}
+\arguments{
+\item{data}{A list containing required data frames and results.}
+
+\item{x_bound}{use to set horizontal axis limit in the plot}
+
+\item{dilute}{being numeric or logical (TRUE/FALSE), it aims to help make the test
+plot for large amount of data faster. If dilute = 5 will plot 1/5 of the
+data. If dilute = TRUE or a number, will plot at most half points per
+feature, so the plotting won't be too slow. If you put dilute too high, at
+least 10 points per feature would be kept. If the dataset is too small
+after dilution, will just plot all the data}
+
+\item{scientific}{show the mean|SHAP| in scientific format. If TRUE, label
+format is 0.0E-0, default to FALSE, and the format will be 0.000}
+
+\item{my_format}{supply your own number format if you really want}
+
+\item{min_color_bound}{min color hex code for colormap. Color gradient is
+scaled between min_color_bound and max_color_bound. Default is "#FFCC33".}
+
+\item{max_color_bound}{max color hex code for colormap. Color gradient is
+scaled between min_color_bound and max_color_bound. Default is "#6600CC".}
+
+\item{kind}{By default, a "sina" plot is shown. As an alternative,
+set \code{kind = "bar"} to visualize mean absolute SHAP values as a
+barplot. Its color is controlled by \code{max_color_bound}. Other
+arguments are ignored for this kind of plot.}
+
+\item{title}{A character string to customize the title}
+
+\item{title.position}{A numeric value from 0.0-1.0 to adjust the alignment of the title}
+
+\item{ylab}{A character string to customize the y-axis}
+
+\item{xlab}{A character string to customize the x-axis}
+}
+\value{
+A list of ggplot objects, each representing a SHAP summary plot for a different parameter.
+}
+\description{
+This function generates SHAP summary plots for the XGBoost model.
+}
+\examples{
+# Assuming 'data' is a list with necessary components
+\dontrun{
+plots <- generate_shap_summary_plot(data, title = "Custom Title", title.position = 0.5)
+}
+
+}
diff --git a/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
new file mode 100644
index 0000000..0853488
--- /dev/null
+++ b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg
@@ -0,0 +1,246 @@
+
+
diff --git a/tests/testthat/_snaps/mlcovsearch/shap-plots.svg b/tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg
similarity index 61%
rename from tests/testthat/_snaps/mlcovsearch/shap-plots.svg
rename to tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg
index 7b7c2b9..6ef9ef6 100644
--- a/tests/testthat/_snaps/mlcovsearch/shap-plots.svg
+++ b/tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg
@@ -39,366 +39,366 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
0.238
0.173
diff --git a/tests/testthat/test-generate_shap_summary_plot.R b/tests/testthat/test-generate_shap_summary_plot.R
new file mode 100644
index 0000000..ab4b783
--- /dev/null
+++ b/tests/testthat/test-generate_shap_summary_plot.R
@@ -0,0 +1,58 @@
+# Return the lower-cased operating system name, with Darwin reported as "osx".
+get_os <- function() {
+  sys_info <- Sys.info()
+  if (is.null(sys_info)) {
+    # Sys.info() gave nothing: fall back to the R build's platform details
+    os_name <- .Platform$OS.type
+    if (grepl("^darwin", R.version$os)) os_name <- "osx"
+    if (grepl("linux-gnu", R.version$os)) os_name <- "linux"
+  } else {
+    os_name <- sys_info["sysname"]
+    if (os_name == "Darwin") os_name <- "osx"
+  }
+
+  tolower(os_name)
+}
+
+# Read in the tab2 dataset shipped with the package
+data <- read.table(system.file(package = "mlcov", "supplementary", "tab2"), skip = 1, header = TRUE)
+
+# Search and select covariates. This function can take a few minutes to run
+result <- suppressWarnings(MLCovSearch(data, #NONMEM output
+                                       list_pop_param = c("V1","CL"),
+                                       cov_continuous = c("AGE","WT","HT","BMI","ALB","CRT",
+                                                          "FER","CHOL","WBC","LYPCT","RBC",
+                                                          "HGB","HCT","PLT"),
+                                       cov_factors = c("SEX","RACE","DIAB","ALQ","WACT","SMQ")))
+
+testthat::test_that("generate_shap_plots default plots do not change", {
+  testthat::skip_if_not(get_os() == "windows")
+  # Build the plots with every plotting argument left at its default
+  default_plots <- suppressWarnings(generate_shap_summary_plot(result))
+  # Compare against the stored "default shap plots" snapshot
+  vdiffr::expect_doppelganger("default shap plots", default_plots)
+})
+
+testthat::test_that("generate_shap_plots custom plots do not change", {
+  testthat::skip_if_not(get_os() == "windows")
+
+  # Override the colours, title, axis labels and numeric options; the
+  # remaining arguments repeat their defaults
+  custom_plots <- suppressWarnings(generate_shap_summary_plot(
+    result,
+    x_bound = NULL,
+    dilute = TRUE,
+    scientific = TRUE,
+    my_format = NULL,
+    min_color_bound = "#336699",
+    max_color_bound = "#CC0066",
+    kind = "sina",
+    title = "Test Title",
+    title.position = 0.5,
+    ylab = "Test y-axis",
+    xlab = "Test x-axis"
+  ))
+
+  # Compare against the stored "custom shap plots" snapshot
+  vdiffr::expect_doppelganger("custom shap plots", custom_plots)
+})
\ No newline at end of file
diff --git a/tests/testthat/test-mlcovsearch.R b/tests/testthat/test-mlcovsearch.R
index 94db71f..c9a7564 100644
--- a/tests/testthat/test-mlcovsearch.R
+++ b/tests/testthat/test-mlcovsearch.R
@@ -39,12 +39,6 @@ testthat::test_that("MLCovSearch result_5folds is a data.frame with 5 non-NA cov
testthat::expect_true(sum(!is.na(result$result_5folds[1,])) == 5)
})
-# This test will need removed when shap plots are removed from MLCovSearch
-testthat::test_that("MLCovSearch shap plots do not change", {
- testthat::skip_if_not(get_os() == "windows")
- vdiffr::expect_doppelganger("shap plots", result$shap_plots)
-})
-
testthat::test_that("MLCovSearch returns an object of class `mlcov_data` with 5 components", {
testthat::expect_true(inherits(result, "mlcov_data"))