Skip to content

Commit

Permalink
Add XGB explainability output (NVIDIA#3044)
Browse files Browse the repository at this point in the history
* Add XGB explainability output

* typo fix

* format fix
  • Loading branch information
ZiyueXu77 authored and YuanTingHsieh committed Dec 10, 2024
1 parent bc46189 commit aa77116
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 5 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,19 @@ def __init__(self, root_dir: str, file_postfix: str):
self.file_postfix = file_postfix
for name in self.dataset_names:
self.base_file_names[name] = name + file_postfix
self.numerical_columns = [f"V_{i}" for i in range(64)]

self.numerical_columns = [
"Timestamp",
"Amount",
"trans_volume",
"total_amount",
"average_amount",
"hist_trans_volume",
"hist_total_amount",
"hist_average_amount",
"x2_y1",
"x3_y2",
] + [f"V_{i}" for i in range(64)]

def initialize(
self, client_id: str, rank: int, data_split_mode: xgb.core.DataSplitMode = xgb.core.DataSplitMode.ROW
Expand All @@ -40,11 +52,10 @@ def initialize(
def load_data(self) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
data = {}
for ds_name in self.dataset_names:
print("\nloading for site = ", self.client_id, f"{ds_name} dataset")
print("\nloading for site = ", self.client_id, f"{ds_name} dataset \n")
file_name = os.path.join(self.root_dir, self.client_id, self.base_file_names[ds_name])
print(file_name)
print(self.numerical_columns)
print("\n")
df = pd.read_csv(file_name)
data_num = len(data)

Expand Down
4 changes: 2 additions & 2 deletions examples/advanced/finance-end-to-end/nvflare/xgb_job_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def define_parser():
"--file_postfix",
type=str,
nargs="?",
default="_embedding.csv",
help="file ending postfix, such as '.csv', or '_embedding.csv'",
default="_combined.csv",
help="file ending postfix, such as '.csv', or '_combined.csv'",
)

parser.add_argument("-co", "--config_only", action="store_true", help="config only mode, will not run simulator")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import os
from typing import Tuple

import matplotlib.pyplot as plt
import shap
import xgboost as xgb
from xgboost import callback

Expand Down Expand Up @@ -222,6 +224,16 @@ def run(self, ctx: dict):
bst.save_model(os.path.join(self._model_dir, self.model_file_name))
xgb.collective.communicator_print("Finished training\n")

# Save explainability outputs based on val_data
explainer = shap.TreeExplainer(bst)
explanation = explainer(val_data)

# save the beeswarm plot to png file
shap.plots.beeswarm(explanation, show=False)
img = plt.gcf()
img.subplots_adjust(left=0.3, right=0.9, bottom=0.3, top=0.9)
img.savefig(os.path.join(self._model_dir, "shap_beeswarm.png"), bbox_inches="tight")

self._stopped = True

def stop(self):
Expand Down

0 comments on commit aa77116

Please sign in to comment.