diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py index 66a52a100a..f4fb01a441 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py @@ -15,6 +15,8 @@ import os from typing import Tuple +import matplotlib.pyplot as plt +import shap import xgboost as xgb from xgboost import callback @@ -222,6 +224,17 @@ def run(self, ctx: dict): bst.save_model(os.path.join(self._model_dir, self.model_file_name)) xgb.collective.communicator_print("Finished training\n") + if self._data_split_mode == 0: + # Save explainability outputs based on val_data + explainer = shap.TreeExplainer(bst) + explanation = explainer(val_data) + + # save the beeswarm plot to png file + shap.plots.beeswarm(explanation, show=False) + img = plt.gcf() + img.subplots_adjust(left=0.3, right=0.9, bottom=0.3, top=0.9) + img.savefig(os.path.join(self._model_dir, "shap_beeswarm.png"), bbox_inches="tight") + self._stopped = True def stop(self):