# example_usage.py
# %% General imports
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from text_explainability.data import from_string, import_data, train_test_split
from text_explainability.model import import_model
# %% Create train/test dataset
env = import_data('./datasets/test.csv', data_cols='fulltext', label_cols='label')
env = train_test_split(env, train_size=0.70)
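# NOTE: based on how `env` is used below, the split environment exposes the training
# instances as env['train'], the full dataset as env.dataset and the labels as env.labels.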
# %% Create sklearn model with pipeline
pipeline = Pipeline([('tfidf', TfidfVectorizer(use_idf=True)),
                     ('rf', RandomForestClassifier(random_state=0))])
# %% Wrap sklearn model
model = import_model(pipeline, environment=env)
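# `import_model` wraps the sklearn pipeline behind a model interface that the explainers
# below can query for predictions; it is assumed here (not verified) that the wrapper also
# handles fitting the pipeline on the environment's training split.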
# %% Imports
from text_explainability.data import from_string
from text_explainability.data.augmentation import LeaveOut, TokenReplacement
from text_explainability.global_explanation import (KMedoids, LabelwiseMMDCritic, MMDCritic, TokenFrequency,
                                                    TokenInformation)
from text_explainability.local_explanation import BayLIME, LIME, Anchor, FoilTree, KernelSHAP, LocalRules, LocalTree
from text_explainability.utils import PUNCTUATION, default_detokenizer
# %% Create example instance (Dutch; roughly: 'This is a very positive or negative test... Or negative after all?')
sample = from_string('Dit is zeer positieve of negatieve proef... Of toch negatief?')
# %% Token replacement augmenter for perturbing the sample
repl = TokenReplacement(env, default_detokenizer)
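# TokenReplacement perturbs the instance by substituting tokens with a replacement token.
# The flags used below are read as: sequential = replace tokens one at a time in order,
# non-sequential = replace random subsets, contiguous = replace consecutive spans
# (interpretation of the flag names, not verified against the library docs).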
# %% Sequential replacement, 10 samples
print(list(repl(sample, n_samples=10).all_data()))
# %% Non-sequential replacement, 10 samples
print(list(repl(sample, n_samples=10, sequential=False).all_data()))
# %% Non-sequential, contiguous replacement, 10 samples
print(list(repl(sample, n_samples=10, sequential=False, contiguous=True).all_data()))
# %% Sequential deletion, 10 samples
print(list(LeaveOut(env, default_detokenizer)(sample, n_samples=10).all_data()))
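# LeaveOut perturbs the instance by deleting tokens instead of replacing them.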
# %% LIME explainer for `sample` on `model`
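# LIME fits a sparse linear surrogate on perturbed copies of `sample`, weighted by proximity;
# `.scores` holds the per-token importances for the requested labels.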
explainer = LIME(env)
explainer(sample, model, labels=['neutraal', 'positief']).scores
# %% Local tree explainer for `sample` on `model` (non-weighted neighborhood data)
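# LocalTree fits a decision-tree surrogate on the neighborhood data; `.rules` exposes its
# decision paths as rules. `weigh_samples=False` skips distance-weighting of the samples.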
LocalTree()(sample, model, weigh_samples=False).rules
# %% SHAP explanation for `sample` on `model`, limited to 4 features
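# KernelSHAP approximates Shapley values with a weighted linear model; here `l1_reg=4` is
# used to restrict the explanation to at most 4 features, as the cell header states.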
KernelSHAP(labelset=env.labels)(sample, model, n_samples=100, l1_reg=4)
# %% Bayesian extension of LIME with 1000 samples
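# BayLIME is a Bayesian extension of LIME that places priors on the surrogate weights,
# aiming for more stable importance estimates across runs.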
BayLIME()(sample, model, n_samples=1000)
# %% Anchor explanation for `sample` on `model`
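# Anchor searches for a small set of tokens that, when present, (almost) always yields the
# same prediction; left commented out here, presumably because the search is comparatively slow.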
# Anchor(label_names=['neg', 'net', 'pos'])(sample, model)
# %% FoilTree explanation for `sample` on `model` (why not 'positief'?)
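# FoilTree gives a contrastive explanation: why the model predicted the fact rather than the
# given foil label ('positief' here), expressed as decision-tree rules.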
FoilTree()(sample, model, 'positief').rules
# %% LocalRules on `model` (why 'negatief'?)
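# LocalRules induces rules on the neighborhood data describing when the model predicts the
# given label ('negatief' here).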
LocalRules()(sample, model, 'negatief', n_samples=100).rules
# %% Global word frequency explanation on ground-truth labels
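# TokenFrequency returns the k most frequent tokens per label, either for the ground-truth
# labels (explain_model=False) or for the model's predictions (explain_model=True).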
tf = TokenFrequency(env['train'])
tf(labelprovider=env.labels, explain_model=False, k=10).scores
# %% Global word frequency explanation on model predictions
tf(model=model, explain_model=True, k=3, filter_words=PUNCTUATION)
# %% Token information for dataset
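# TokenInformation scores how informative each token is for the labels (e.g. via mutual
# information), again either against ground-truth labels or against model predictions.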
ti = TokenInformation(env['train'])
ti(labelprovider=env.labels, explain_model=False, k=25).scores
# %% Token information for model
ti(model=model, explain_model=True, k=25, filter_words=PUNCTUATION)
# %% Extract top-2 prototypes with KMedoids
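# KMedoids selects the n instances that best represent the dataset (the cluster medoids)
# as prototypes.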
KMedoids(env.dataset).prototypes(n=2)
# %% Extract top-2 prototypes and top-2 criticisms with MMDCritic
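# MMDCritic selects prototypes that represent the data distribution well (via maximum mean
# discrepancy) plus criticisms: instances the prototypes represent poorly.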
MMDCritic(env.dataset)(n_prototypes=2, n_criticisms=2)
# %% Extract 1 prototype for each ground-truth label with LabelwiseMMDCritic
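# LabelwiseMMDCritic does the same per label: with env.labels it groups instances by
# ground-truth label, and with `model` (last cell) by predicted label.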
LabelwiseMMDCritic(env.dataset, env.labels).prototypes(n=1)
# %% Extract 1 prototype and 2 criticisms for each ground-truth label with LabelwiseMMDCritic
LabelwiseMMDCritic(env.dataset, env.labels)(n_prototypes=1, n_criticisms=2)
# %% Extract 1 prototype and 1 criticism for each predicted label with LabelwiseMMDCritic
LabelwiseMMDCritic(env.dataset, model)(n_prototypes=1, n_criticisms=1)