Skip to content

Commit

Permalink
disable SHAP for NN
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed Jul 9, 2020
1 parent 495d82f commit 01e8639
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions supervised/utils/shap.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def get_explainer(algorithm, X_train):
explainer = shap.TreeExplainer(algorithm.model)
elif algorithm.algorithm_short_name in ["Linear"]:
explainer = shap.LinearExplainer(algorithm.model, X_train)
elif algorithm.algorithm_short_name in ["Neural Network"]:
explainer = shap.KernelExplainer(algorithm.model, X_train)

#elif algorithm.algorithm_short_name in ["Neural Network"]:
# explainer = shap.GradientExplainer(algorithm.model, X_train) #, link="logit")
return explainer

@staticmethod
Expand Down Expand Up @@ -186,6 +186,7 @@ def compute(
explainer = PlotSHAP.get_explainer(algorithm, X_train)
X_vald, y_vald = PlotSHAP.get_sample(X_validation, y_validation)
shap_values = explainer.shap_values(X_vald)


# fix problem with 1 or 2 dimensions for binary classification
expected_value = explainer.expected_value
Expand Down Expand Up @@ -245,6 +246,7 @@ def compute(
class_names,
)
except Exception as e:
print(f"Exception while producing SHAP explanations. {str(e)}\nContinuing ...")
logger.info(
f"Exception while producing SHAP explanations. {str(e)}\nContinuing ..."
)
Expand Down

0 comments on commit 01e8639

Please sign in to comment.