Skip to content

Commit

Permalink
Add TreeShap explainer
Browse files Browse the repository at this point in the history
  • Loading branch information
ukclivecox authored and seldondev committed Sep 2, 2020
1 parent 2285b68 commit 094bf25
Show file tree
Hide file tree
Showing 8 changed files with 422 additions and 11 deletions.
11 changes: 11 additions & 0 deletions components/alibi-explain-server/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,14 @@ run_explainer_integratedgradients_docker:

curl_explain_imdb:
curl -d '{"data": {"ndarray":[[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 591, 202, 14, 31, 6, 717, 10, 10, 2, 2, 5, 4, 360, 7, 4, 177, 5760, 394, 354, 4, 123, 9, 1035, 1035, 1035, 10, 10, 13, 92, 124, 89, 488, 7944, 100, 28, 1668, 14, 31, 23, 27, 7479, 29, 220, 468, 8, 124, 14, 286, 170, 8, 157, 46, 5, 27, 239, 16, 179, 2, 38, 32, 25, 7944, 451, 202, 14, 6, 717]]}}' -X POST http://localhost:8080/api/v1.0/explain -H "Content-Type: application/json"


#
# Test Tree Shap
#

run_explainer_treeshap:
python -m alibiexplainer --model_name adult --protocol seldon.http --storage_uri gs://seldon-models/xgboost/adult/tree_shap_py36_0.5.2 TreeShap

run_explainer_treeshap_docker:
docker run --rm -d --name "explainer" --network=host -p 8080:8080 seldonio/${IMAGE}:${VERSION} --model_name adult --protocol seldon.http --storage_uri gs://seldon-models/xgboost/adult/tree_shap_py36_0.5.2 TreeShap
6 changes: 5 additions & 1 deletion components/alibi-explain-server/alibiexplainer/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from alibiexplainer.anchor_tabular import AnchorTabular
from alibiexplainer.anchor_text import AnchorText
from alibiexplainer.kernel_shap import KernelShap
from alibiexplainer.tree_shap import TreeShap
from alibiexplainer.integrated_gradients import IntegratedGradients
from alibiexplainer.explainer_wrapper import ExplainerWrapper
from alibiexplainer.proto import prediction_pb2
Expand Down Expand Up @@ -59,6 +60,7 @@ class ExplainerMethod(Enum):
anchor_text = "AnchorText"
kernel_shap = "KernelShap"
integrated_gradients = "IntegratedGradients"
tree_shap = "TreeShap"

def __str__(self):
return self.value
Expand Down Expand Up @@ -93,6 +95,8 @@ def __init__(self,
self.wrapper = KernelShap(self._predict_fn, explainer, **config)
elif self.method is ExplainerMethod.integrated_gradients:
self.wrapper = IntegratedGradients(keras_model, **config)
elif self.method is ExplainerMethod.tree_shap:
self.wrapper = TreeShap(explainer, **config)
else:
raise NotImplementedError

Expand Down Expand Up @@ -135,7 +139,7 @@ def _predict_fn(self, arr: Union[np.ndarray, List]) -> np.ndarray:
def explain(self, request: Dict) -> Any:
if self.method is ExplainerMethod.anchor_tabular or self.method is ExplainerMethod.anchor_images or \
self.method is ExplainerMethod.anchor_text or self.method is ExplainerMethod.kernel_shap or \
self.method is ExplainerMethod.integrated_gradients:
self.method is ExplainerMethod.integrated_gradients or self.method is ExplainerMethod.tree_shap:
if self.protocol == Protocol.tensorflow_http:
explanation: Explanation = self.wrapper.explain(request["instances"])
else:
Expand Down
45 changes: 45 additions & 0 deletions components/alibi-explain-server/alibiexplainer/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,51 @@ def parse_args(sys_args):
dest="explainer.internal_batch_size",
default=argparse.SUPPRESS,
)

# TreeShap Arguments
parser_tree_shap = subparsers.add_parser(str(ExplainerMethod.tree_shap))
addCommonParserArgs(parser_tree_shap)

parser_tree_shap.add_argument(
"--interactions",
type=str2bool,
action=GroupedAction,
dest="explainer.interactions",
default=argparse.SUPPRESS,
)

parser_tree_shap.add_argument(
"--approximate",
type=str2bool,
action=GroupedAction,
dest="explainer.approximate",
default=argparse.SUPPRESS,
)

parser_tree_shap.add_argument(
"--check_additivity",
type=str2bool,
action=GroupedAction,
dest="explainer.check_additivity",
default=argparse.SUPPRESS,
)

parser_tree_shap.add_argument(
"--tree_limit",
type=int,
action=GroupedAction,
dest="explainer.tree_limit",
default=argparse.SUPPRESS,
)

parser_tree_shap.add_argument(
"--summarise_result",
type=str2bool,
action=GroupedAction,
dest="explainer.summarise_result",
default=argparse.SUPPRESS,
)

args, _ = parser.parse_known_args(sys_args)

argdDict = vars(args).copy()
Expand Down
28 changes: 28 additions & 0 deletions components/alibi-explain-server/alibiexplainer/tree_shap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging
import numpy as np
import alibi
from alibi.api.interfaces import Explanation
from alibiexplainer.explainer_wrapper import ExplainerWrapper
from alibiexplainer.constants import SELDON_LOGLEVEL
from typing import List, Optional

logging.basicConfig(level=SELDON_LOGLEVEL)


class TreeShap(ExplainerWrapper):
def __init__(
self,
explainer: Optional[alibi.explainers.TreeShap],
**kwargs
):
if explainer is None:
raise Exception("Tree Shap requires a built explainer")
self.tree_shap = explainer
self.kwargs = kwargs

def explain(self, inputs: List) -> Explanation:
arr = np.array(inputs)
logging.info("Tree Shap call with %s", self.kwargs)
logging.info("kernel shap data shape %s",arr.shape)
shap_exp = self.tree_shap.explain(arr, **self.kwargs)
return shap_exp
4 changes: 3 additions & 1 deletion components/alibi-explain-server/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
"requests>=2.22.0",
"joblib>=0.13.2",
"dill>=0.3.0",
"grpcio>=1.22.0"
"grpcio>=1.22.0",
"xgboost==1.0.2",
"shap==0.35.0"
],
tests_require=tests_require,
extras_require={'test': tests_require}
Expand Down
25 changes: 25 additions & 0 deletions components/alibi-explain-server/tests/test_tree_shap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from alibiexplainer.tree_shap import TreeShap
import kfserving
import os
import dill
from alibi.datasets import fetch_adult
import numpy as np
import json
ADULT_EXPLAINER_URI = "gs://seldon-models/xgboost/adult/tree_shap_py36_0.5.2"
EXPLAINER_FILENAME = "explainer.dill"


def test_kernel_shap():
os.environ.clear()
alibi_model = os.path.join(
kfserving.Storage.download(ADULT_EXPLAINER_URI), EXPLAINER_FILENAME
)
with open(alibi_model, "rb") as f:
alibi_model = dill.load(f)
tree_shap = TreeShap(alibi_model)
adult = fetch_adult()
X_test = adult.data[30001:, :]
np.random.seed(0)
explanation = tree_shap.explain(X_test[0:1].tolist())
exp_json = json.loads(explanation.to_json())
print(exp_json)
Loading

0 comments on commit 094bf25

Please sign in to comment.