Skip to content

Commit

Permalink
cherry pick explanability part to 2.5
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Dec 10, 2024
1 parent 397983b commit 0dc3069
Showing 1 changed file with 13 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import os
from typing import Tuple

import matplotlib.pyplot as plt
import shap
import xgboost as xgb
from xgboost import callback

Expand Down Expand Up @@ -222,6 +224,17 @@ def run(self, ctx: dict):
bst.save_model(os.path.join(self._model_dir, self.model_file_name))
xgb.collective.communicator_print("Finished training\n")

if self._data_split_mode == 0:
# Save explainability outputs based on val_data
explainer = shap.TreeExplainer(bst)
explanation = explainer(val_data)

# save the beeswarm plot to png file
shap.plots.beeswarm(explanation, show=False)
img = plt.gcf()
img.subplots_adjust(left=0.3, right=0.9, bottom=0.3, top=0.9)
img.savefig(os.path.join(self._model_dir, "shap_beeswarm.png"), bbox_inches="tight")

self._stopped = True

def stop(self):
Expand Down

0 comments on commit 0dc3069

Please sign in to comment.