diff --git a/NAMESPACE b/NAMESPACE index d22810c..4198ea8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,3 +3,7 @@ S3method(print,mlcov_data) export(MLCovSearch) export(generate_residualsplots) +export(generate_shap_summary_plot) +import(SHAPforxgboost) +import(ggplot2) +import(xgboost) diff --git a/R/cov_search.R b/R/cov_search.R index 31a6193..3123449 100644 --- a/R/cov_search.R +++ b/R/cov_search.R @@ -199,60 +199,48 @@ MLCovSearch <- function(tab, list_pop_param, cov_continuous, cov_factors, seed = } } - # Initialize an empty list to store the SHAP summary plots - shap_plots <- list() - - # Create a function to generate SHAP summary plots - generate_shap_summary_plot <- function(xgb_model, X_train, param_name) { - shap_values <- SHAPforxgboost::shap.values(xgb_model = xgb_model, X_train = X_train) - shap_long <- SHAPforxgboost::shap.prep(xgb_model = xgb_model, X_train = X_train) - p <- SHAPforxgboost::shap.plot.summary(shap_long) - p <- p + ggplot2::ggtitle(param_name) - - return(p) - } + # Initialize an empty list to store the SHAP summary data and seed information + shap_data <- list() + shap_seed <- list() # Interpretation of Selected covariates Beeswarm Plots for (i in list_pop_param) { y_xgb <- log(dat_XGB[, i]) - - if (is.na(result_ML[i, 1]) == FALSE){ - list_cov <- strsplit(gsub(" ", "", result_ML[i, 1]), ",") - x.selected_final <- as.matrix(dat_XGB %>% dplyr::select(dplyr::all_of(list_cov[[1]]))) - - if (length(list_cov[[1]]) != 0 ) { - xgb.mod_final <- xgboost::xgboost( - data = x.selected_final, - label = y_xgb, - nrounds = 200, - objective = "reg:squarederror", - verbose = 0 - ) - - # Generate SHAP summary plot for the current parameter - shap_plot <- generate_shap_summary_plot( - xgb_model = xgb.mod_final, - X_train = x.selected_final, - param_name = i - ) - shap_plots[[i]] <- shap_plot - } + + if (is.na(result_ML[i, 1]) == FALSE) { + list_cov <- strsplit(gsub(" ", "", result_ML[i, 1]), ",") + x.selected_final <- + as.matrix(dat_XGB %>% 
dplyr::select(dplyr::all_of(list_cov[[1]]))) + + if (length(list_cov[[1]]) != 0) { + xgb.mod_final <- xgboost::xgboost( + data = x.selected_final, + label = y_xgb, + nrounds = 200, + objective = "reg:squarederror", + verbose = 0 + ) + + # Generate SHAP summary plot for the current parameter + shap_values <- SHAPforxgboost::shap.values(xgb_model = xgb.mod_final, X_train = x.selected_final) + shap_long <- SHAPforxgboost::shap.prep(xgb_model = xgb.mod_final, X_train = x.selected_final) + + # Store shap data and seed + shap_data[[i]] <- list(shap_values = shap_values, shap_long = shap_long) + shap_seed[[i]] <- .Random.seed + + } } } - combined_plots <- gridExtra::marrangeGrob(grobs = shap_plots,nrow = length(shap_plots),ncol = 1) - - - - # Return the result_ML table and the SHAP plots for each parameter return( list( result_ML = result_ML, result_5folds = result_5folds, - shap_plots = shap_plots, list_pop_param = list_pop_param, - dat_XGB = dat_XGB + shap_data = shap_data, + shap_seed = shap_seed ) %>% structure(class = "mlcov_data") ) } diff --git a/R/generate_shap_summary_plot.R b/R/generate_shap_summary_plot.R new file mode 100644 index 0000000..f86740f --- /dev/null +++ b/R/generate_shap_summary_plot.R @@ -0,0 +1,71 @@ +#' Generate SHAP Summary Plots +#' +#' This function generates SHAP summary plots for the XGBoost model. +#' @inheritParams SHAPforxgboost::shap.plot.summary +#' @param data A list containing required data frames and results. +#' @param title A character string to customize the title +#' @param title.position A numeric value from 0.0-1.0 to adjust the alignment of the title +#' @param ylab A character string to customize the y-axis +#' @param xlab A character string to customize the x-axis +#' @return A list of ggplot objects, each representing a SHAP summary plot for a different parameter. 
+#'
+#' @examples
+#' # `result` is the object returned by MLCovSearch()
+#' \dontrun{
+#' plots <- generate_shap_summary_plot(result, title = "Custom Title")
+#' }
+#'
+#' @import xgboost
+#' @import ggplot2
+#' @import SHAPforxgboost
+#'
+#' @export
+#'
+generate_shap_summary_plot <- function(data,
+                                       x_bound = NULL,
+                                       dilute = FALSE,
+                                       scientific = FALSE,
+                                       my_format = NULL,
+                                       min_color_bound = "#FFCC33",
+                                       max_color_bound = "#6600CC",
+                                       kind = c("sina", "bar"),
+                                       title = NULL,
+                                       title.position = 0,
+                                       ylab = NULL,
+                                       xlab = NULL)
+{
+  kind <- match.arg(kind)
+  shap_plots <- list()
+
+  for (i in data$list_pop_param) {
+    # Parameters without selected covariates carry no SHAP data: skip them
+    shap_entry <- data$shap_data[[i]]
+    if (is.null(shap_entry)) next
+
+    # `shap_seed` stores a full `.Random.seed` state vector, not a scalar seed.
+    # `set.seed()` expects a single integer, so it cannot restore that state;
+    # put the saved state back directly instead.
+    rng_state <- data$shap_seed[[i]]
+    if (is.integer(rng_state) && length(rng_state) > 1L) {
+      assign(".Random.seed", rng_state, envir = globalenv())
+    }
+
+    p <- SHAPforxgboost::shap.plot.summary(
+      shap_entry$shap_long,
+      x_bound = x_bound,
+      dilute = dilute,
+      scientific = scientific,
+      my_format = my_format,
+      min_color_bound = min_color_bound,
+      max_color_bound = max_color_bound,
+      kind = kind
+    )
+    shap_plots[[i]] <- p +
+      ggplot2::ggtitle(if (is.null(title)) i else title) +
+      ggplot2::labs(y = if (is.null(ylab)) "SHAP value (impact on model output)" else ylab,
+                    x = if (is.null(xlab)) "" else xlab) +
+      ggplot2::theme(plot.title = ggplot2::element_text(hjust = title.position))
+  }
+
+  shap_plots
+}
diff --git a/README.md b/README.md index 719c3e7..72c985e 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,21 @@ print(result) ``` +Generate SHAP plots: + +``` +generate_shap_summary_plot( + result, + x_bound = NULL, + dilute = FALSE, + scientific = FALSE, + my_format = NULL, + title = NULL, + title.position = 0.5, + ylab = NULL, + xlab = NULL) +``` + Generate residual plots: Cl diff --git a/man/generate_shap_summary_plot.Rd b/man/generate_shap_summary_plot.Rd new file mode 100644 index
0000000..045eb6e --- /dev/null +++ b/man/generate_shap_summary_plot.Rd @@ -0,0 +1,70 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generate_shap_summary_plot.R +\name{generate_shap_summary_plot} +\alias{generate_shap_summary_plot} +\title{Generate SHAP Summary Plots} +\usage{ +generate_shap_summary_plot( + data, + x_bound = NULL, + dilute = FALSE, + scientific = FALSE, + my_format = NULL, + min_color_bound = "#FFCC33", + max_color_bound = "#6600CC", + kind = c("sina", "bar"), + title = NULL, + title.position = 0, + ylab = NULL, + xlab = NULL +) +} +\arguments{ +\item{data}{A list containing required data frames and results.} + +\item{x_bound}{use to set horizontal axis limit in the plot} + +\item{dilute}{being numeric or logical (TRUE/FALSE), it aims to help make the test +plot for large amount of data faster. If dilute = 5 will plot 1/5 of the +data. If dilute = TRUE or a number, will plot at most half points per +feature, so the plotting won't be too slow. If you put dilute too high, at +least 10 points per feature would be kept. If the dataset is too small +after dilution, will just plot all the data} + +\item{scientific}{show the mean|SHAP| in scientific format. If TRUE, label +format is 0.0E-0, default to FALSE, and the format will be 0.000} + +\item{my_format}{supply your own number format if you really want} + +\item{min_color_bound}{min color hex code for colormap. Color gradient is +scaled between min_color_bound and max_color_bound. Default is "#FFCC33".} + +\item{max_color_bound}{max color hex code for colormap. Color gradient is +scaled between min_color_bound and max_color_bound. Default is "#6600CC".} + +\item{kind}{By default, a "sina" plot is shown. As an alternative, +set \code{kind = "bar"} to visualize mean absolute SHAP values as a +barplot. Its color is controlled by \code{max_color_bound}. 
Other +arguments are ignored for this kind of plot.} + +\item{title}{A character string to customize the title} + +\item{title.position}{A numeric value from 0.0-1.0 to adjust the alignment of the title} + +\item{ylab}{A character string to customize the y-axis} + +\item{xlab}{A character string to customize the x-axis} +} +\value{ +A list of ggplot objects, each representing a SHAP summary plot for a different parameter. +} +\description{ +This function generates SHAP summary plots for the XGBoost model. +} +\examples{ +# Assuming 'data' is a list with necessary components +\dontrun{ +plots <- generate_shap_summary_plot(data, title = "Custom Title", y.inter = 0.25, ...) +} + +} diff --git a/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg new file mode 100644 index 0000000..0853488 --- /dev/null +++ b/tests/testthat/_snaps/generate_shap_summary_plot/custom-shap-plots.svg @@ -0,0 +1,246 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +1.7e-01 +2.4e-01 + + + +WT +AGE + + + +-0.5 +0.0 +0.5 +Test y-axis +Test x-axis + + + Low +High +Feature value + + + + +Test Title + + diff --git a/tests/testthat/_snaps/mlcovsearch/shap-plots.svg b/tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg similarity index 61% rename from tests/testthat/_snaps/mlcovsearch/shap-plots.svg rename to tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg index 7b7c2b9..6ef9ef6 100644 --- a/tests/testthat/_snaps/mlcovsearch/shap-plots.svg +++ 
b/tests/testthat/_snaps/generate_shap_summary_plot/default-shap-plots.svg @@ -39,366 +39,366 @@ - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - + + + + + + + + - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0.238 0.173 diff --git a/tests/testthat/test-generate_shap_summary_plot.R b/tests/testthat/test-generate_shap_summary_plot.R new file mode 100644 index 0000000..ab4b783 --- /dev/null +++ b/tests/testthat/test-generate_shap_summary_plot.R @@ -0,0 +1,58 @@ +get_os <- function(){ + sysinf <- Sys.info() + if (!is.null(sysinf)){ + os <- sysinf['sysname'] + if (os == 'Darwin') + os <- "osx" + } else { ## mystery machine + os <- .Platform$OS.type + if (grepl("^darwin", R.version$os)) + os <- "osx" + 
if (grepl("linux-gnu", R.version$os)) + os <- "linux" + } + tolower(os) +} + +# Read in tab2 dataset +data <- read.table(system.file(package = "mlcov", "supplementary", "tab2"), skip = 1, header = T) + +# Search and select covariates. This function can take a few minutes to run +result <- suppressWarnings(MLCovSearch(data, #NONMEM output + list_pop_param = c("V1","CL"), + cov_continuous = c("AGE","WT","HT","BMI","ALB","CRT", + "FER","CHOL","WBC","LYPCT","RBC", + "HGB","HCT","PLT"), + cov_factors = c("SEX","RACE","DIAB","ALQ","WACT","SMQ"))) + +testthat::test_that("generate_shap_plots default plots do not change", { + testthat::skip_if_not(get_os() == "windows") + + test.plots.default <- suppressWarnings(generate_shap_summary_plot(result)) + + vdiffr::expect_doppelganger("default shap plots", test.plots.default) +}) + +testthat::test_that("generate_shap_plots custom plots do not change", { + testthat::skip_if_not(get_os() == "windows") + + test.plots.custom <- + suppressWarnings( + generate_shap_summary_plot( + result, + x_bound = NULL, + dilute = TRUE, + scientific = TRUE, + my_format = NULL, + min_color_bound = "#336699", + max_color_bound = "#CC0066", + kind = "sina", + title = "Test Title", + title.position = 0.5, + ylab = "Test y-axis", + xlab = "Test x-axis" + ) + ) + + vdiffr::expect_doppelganger("custom shap plots", test.plots.custom) +}) \ No newline at end of file diff --git a/tests/testthat/test-mlcovsearch.R b/tests/testthat/test-mlcovsearch.R index 94db71f..c9a7564 100644 --- a/tests/testthat/test-mlcovsearch.R +++ b/tests/testthat/test-mlcovsearch.R @@ -39,12 +39,6 @@ testthat::test_that("MLCovSearch result_5folds is a data.frame with 5 non-NA cov testthat::expect_true(sum(!is.na(result$result_5folds[1,])) == 5) }) -# This test will need removed when shap plots are removed from MLCovSearch -testthat::test_that("MLCovSearch shap plots do not change", { - testthat::skip_if_not(get_os() == "windows") - vdiffr::expect_doppelganger("shap plots", 
result$shap_plots) -}) - testthat::test_that("MLCovSearch returns an object of class `mlcov_data` with 5 components", { testthat::expect_true(inherits(result, "mlcov_data"))