-
Notifications
You must be signed in to change notification settings - Fork 0
/
shap_bipartite.py
136 lines (117 loc) · 4.78 KB
/
shap_bipartite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from scipy.special import comb, binom
from sklearn import linear_model
import itertools
import random
import numpy as np
from numpy.linalg import inv
class shap_bipartite():
    """Shapley value estimator that regresses model outputs on binary masks.

    A mask is a length-``D`` list marking which features are "on" (1) or
    "off" (0).  ``model`` maps a mask to a real number.  The Shapley values
    are taken as the coefficients of a weighted linear regression (no
    intercept) of ``model(mask)`` on the masks.  Each mask with ``k`` active
    features is weighted by the Shapley kernel
    ``(D - 1) / (C(D, k) * k * (D - k))``.

    Attributes:
        model: callable ``mask -> real`` evaluated by ``solve`` / ``solve_lin_alg``.
        D: number of features (mask length).
        N: stored on the instance but not read by any method of this class.
        num_samples: target number of masks to store.
        samples: the stored masks.
        weights: kernel weight of each entry of ``samples`` (same order).
    """

    def __init__(self, model, D, N, num_samples):
        self.model = model
        self.D = D
        self.N = N
        self.num_samples = num_samples
        self.samples = []
        self.weights = []
        # Cache of kernel weights keyed by mask size (number of active features).
        self._weights = dict()
        # Memoized result of solve().
        self._shap = None
        # Fill self.samples / self.weights with randomly drawn masks.
        self.generate_samples()

    def _get_mask(self, mask_inds):
        """Return a length-D list of floats: 1.0 at each index in ``mask_inds``, else 0.0."""
        mask = [0.] * self.D
        for ind in mask_inds:
            mask[ind] = 1.
        return mask

    def _kernel_weight(self, size):
        """Return the cached Shapley kernel weight for a mask with ``size`` active features.

        Exact integer binomial coefficients are used everywhere, so the cached
        value does not depend on which sampling method filled the cache first.
        ``size`` must lie in ``1..D-1``; sizes 0 and D raise ZeroDivisionError.
        """
        if size not in self._weights:
            self._weights[size] = (self.D - 1) / (
                comb(self.D, size, exact=True) * size * (self.D - size))
        return self._weights[size]

    def _get_weights(self, mask_inds):
        """Return the kernel weight for the mask whose active indices are ``mask_inds``."""
        return self._kernel_weight(len(mask_inds))

    def _update(self, mask_inds):
        """Append the float mask for ``mask_inds`` and its kernel weight."""
        self.samples.append(self._get_mask(mask_inds))
        self.weights.append(self._get_weights(mask_inds))

    def generate_samples(self):
        """Draw random masks until ``num_samples`` masks are stored.

        Each draw is a uniform integer in ``[1, 2**D - 2]``.  Its binary digits,
        least-significant bit first, form the mask, so the empty and the full
        mask are never produced.  Masks are lists of ints (0/1), and
        duplicates are possible.

        Raises:
            ValueError: if more masks are needed but ``D < 2``, since then
                no mask other than the empty/full one exists.
        """
        if len(self.samples) < self.num_samples and self.D < 2:
            raise ValueError("D must be at least 2 to sample non-trivial masks")
        while len(self.samples) < self.num_samples:
            elem = random.randint(1, (2 ** self.D) - 2)
            mask = [(elem >> bit) & 1 for bit in range(self.D)]
            self.weights.append(self._kernel_weight(sum(mask)))
            self.samples.append(mask)

    def generate_samples_old(self):
        """Deterministically enumerate masks from the extreme sizes inwards.

        For ``r = 1, 2, ..., D // 2`` it adds every mask with ``r`` active
        features, then every mask with ``D - r`` active features.  For even D
        the middle size ``D / 2`` is added only once.  Enumeration stops as soon
        as ``num_samples`` masks are stored, and within a size pair the size-r
        masks are taken first.  If the budget covers all ``2**D - 2``
        non-trivial masks, each is added exactly once and enumeration ends.
        Masks are lists of floats (see ``_get_mask``).
        """
        remaining_samples = self.num_samples - len(self.samples)
        ind_elems = list(range(self.D))

        def add_size(size):
            # Add masks of the given size until the budget runs out.
            nonlocal remaining_samples
            for mask_inds in itertools.combinations(ind_elems, size):
                if remaining_samples <= 0:
                    return
                self._update(mask_inds)
                remaining_samples -= 1

        r = 1
        # Sizes r and D - r are handled together.  Advancing r each pass means
        # every non-trivial size is covered once r passes D // 2.
        while len(self.samples) < self.num_samples and r <= self.D // 2:
            add_size(r)
            if self.D - r != r:
                add_size(self.D - r)
            r += 1

    def _evaluate_samples(self):
        """Return ``[self.model(mask) for mask in self.samples]``."""
        return [self.model(samp) for samp in self.samples]

    def solve(self):
        """Estimate Shapley values with a sample-weighted sklearn Lasso fit.

        The result is memoized: later calls return the first result without
        re-evaluating ``model``.

        Returns:
            ``coef_`` of ``linear_model.Lasso(fit_intercept=False)``.  It is
            fitted on the stored masks and their model outputs, with the
            kernel weights as ``sample_weight``.
        """
        if self._shap is not None:
            return self._shap
        results = self._evaluate_samples()
        # NOTE(review): Lasso defaults to alpha=1.0.  Its L1 penalty shrinks
        # coefficients (possibly to zero), so coef_ are regularized estimates
        # rather than the plain weighted-least-squares solution computed by
        # solve_lin_alg.
        regr = linear_model.Lasso(fit_intercept=False)
        regr.fit(self.samples, results, sample_weight=self.weights)
        self._shap = regr.coef_
        return self._shap

    def solve_lin_alg(self):
        """Solve the weighted least-squares problem in closed form.

        Computes ``(X^T W X)^{-1} X^T W y``, where:
          * X stacks the stored masks,
          * W is the diagonal matrix of kernel weights,
          * y holds the model outputs.
        The result is not memoized, so ``model`` is re-evaluated on every call.

        Returns:
            ``numpy.matrix`` with the coefficients.  Its shape is ``(1, D)``
            when ``model`` returns scalars.

        Raises:
            numpy.linalg.LinAlgError: if ``X^T W X`` is singular, e.g. when a
                feature is 0 in every stored mask.
        """
        X = np.asarray(self.samples, dtype=float)
        y = np.asarray(self._evaluate_samples())
        # Scaling the columns of X^T by the weights equals X^T @ diag(w),
        # without materializing an n x n diagonal matrix.
        xtw = X.T * np.asarray(self.weights, dtype=float)
        # Solving the normal equations is more stable than inverting X^T W X.
        shapley = np.linalg.solve(xtw @ X, xtw @ y)
        # Preserve the historical return type: a 1 x D numpy.matrix.
        return np.asmatrix(shapley)