"""
For external use as part of nbdt package. This is a model that
runs inference as an NBDT. Note these make no assumption about the
underlying neural network other than it (1) is a classification model and
(2) returns logits.
"""
from nbdt.utils import (
dataset_to_default_path_graph,
dataset_to_default_path_wnids,
hierarchy_to_path_graph,
coerce_tensor,
uncoerce_tensor,
)
from nbdt.models.utils import load_state_dict_from_key, coerce_state_dict
from nbdt.tree import Node, Tree
from nbdt.thirdparty.wn import get_wnids, synset_to_name
from nbdt.thirdparty.nx import get_root
from torch.distributions import Categorical
import torch
import torch.nn as nn
import torch.nn.functional as F


model_urls = {
(
"ResNet18",
"CIFAR10",
): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR10-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth",
(
"wrn28_10_cifar10",
"CIFAR10",
): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR10-wrn28_10_cifar10-induced-wrn28_10_cifar10-SoftTreeSupLoss.pth",
(
"wrn28_10_cifar10",
"CIFAR10",
"wordnet",
): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR10-wrn28_10_cifar10-wordnet-SoftTreeSupLoss.pth",
(
"ResNet18",
"CIFAR100",
): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-ResNet18-induced-ResNet18-SoftTreeSupLoss.pth",
(
"wrn28_10_cifar100",
"CIFAR100",
): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-CIFAR100-wrn28_10_cifar100-induced-wrn28_10_cifar100-SoftTreeSupLoss.pth",
(
"ResNet18",
"TinyImagenet200",
): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-TinyImagenet200-ResNet18-induced-ResNet18-SoftTreeSupLoss-tsw10.0.pth",
(
"wrn28_10",
"TinyImagenet200",
): "https://github.com/alvinwan/neural-backed-decision-trees/releases/download/0.0.1/ckpt-TinyImagenet200-wrn28_10-induced-wrn28_10-SoftTreeSupLoss-tsw10.0.pth",
}


#########
# RULES #
#########
class EmbeddedDecisionRules(nn.Module):
def __init__(
self,
dataset=None,
path_graph=None,
path_wnids=None,
classes=(),
hierarchy=None,
tree=None,
):
super().__init__()
if not tree:
tree = Tree(dataset, path_graph, path_wnids, classes, hierarchy=hierarchy)
self.tree = tree
self.correct = 0
self.total = 0
self.I = torch.eye(len(self.tree.classes))
@staticmethod
def get_node_logits(outputs, node=None, new_to_old_classes=None, num_classes=None):
"""Get output for a particular node
This `outputs` above are the output of the neural network.
"""
assert node or (
new_to_old_classes and num_classes
), "Either pass node or (new_to_old_classes mapping and num_classes)"
new_to_old_classes = new_to_old_classes or node.child_index_to_class_index
num_classes = num_classes or node.num_classes
return torch.stack(
[
outputs.T[new_to_old_classes[child_index]].mean(dim=0)
for child_index in range(num_classes)
]
).T
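    # A worked sketch of the reduction above (values are hypothetical): with
    # `outputs` of shape (N, 10) and a node whose child 0 covers original
    # classes [3, 7], that child's logit for each sample is
    # outputs[:, [3, 7]].mean(dim=1), i.e. the mean of the wrapped network's
    # logits over the classes grouped under that child.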
@classmethod
def get_all_node_outputs(cls, outputs, nodes):
"""Run hard embedded decision rules.
Returns the output for *every single node.
"""
wnid_to_outputs = {}
for node in nodes:
node_logits = cls.get_node_logits(outputs, node)
node_outputs = {"logits": node_logits}
if len(node_logits.size()) > 1:
node_outputs["preds"] = torch.max(node_logits, dim=1)[1]
node_outputs["probs"] = F.softmax(node_logits, dim=1)
node_outputs["entropy"] = Categorical(
probs=node_outputs["probs"]
).entropy()
wnid_to_outputs[node.wnid] = node_outputs
return wnid_to_outputs
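    # For reference, `wnid_to_outputs` maps each internal node's wnid to a
    # dict with the key "logits" of shape (N, num_children) and, for batched
    # (2-D) logits, "preds" of shape (N,), "probs" of shape
    # (N, num_children), and "entropy" of shape (N,), exactly as built above.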
def forward_nodes(self, outputs):
return self.get_all_node_outputs(outputs, self.tree.inodes)


class HardEmbeddedDecisionRules(EmbeddedDecisionRules):
@classmethod
def get_node_logits_filtered(cls, node, outputs, targets):
"""'Smarter' inference for a hard node.
If you have targets for the node, you can selectively perform inference,
only for nodes where the label of a sample is well-defined.
"""
classes = [node.class_index_to_child_index[int(t)] for t in targets]
        selector = [bool(child_classes) for child_classes in classes]
        targets_sub = [child_classes[0] for child_classes in classes if child_classes]
outputs = outputs[selector]
if outputs.size(0) == 0:
return selector, outputs[:, : node.num_classes], targets_sub
outputs_sub = cls.get_node_logits(outputs, node)
return selector, outputs_sub, targets_sub
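    # A hypothetical example of the filtering above: for a node with
    # class_index_to_child_index = {0: [0], 1: [0], 2: [1], 3: []} and
    # targets [0, 3, 2], `selector` is [True, False, True] (sample 1's label
    # is undefined at this node) and `targets_sub` is [0, 1].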
@classmethod
def traverse_tree(cls, wnid_to_outputs, tree):
"""Convert node outputs to final prediction.
Note that the prediction output for this function can NOT be trained
on. The outputs have been detached from the computation graph.
"""
# move all to CPU, detach from computation graph
example = wnid_to_outputs[tree.inodes[0].wnid]
n_samples = int(example["logits"].size(0))
device = example["logits"].device
for wnid in tuple(wnid_to_outputs.keys()):
outputs = wnid_to_outputs[wnid]
outputs["preds"] = list(map(int, outputs["preds"].cpu()))
outputs["probs"] = outputs["probs"].detach().cpu()
decisions = []
preds = []
for index in range(n_samples):
decision = [{"node": tree.root, "name": "root", "prob": 1, "entropy": 0}]
node = tree.root
while not node.is_leaf():
                if node.wnid not in wnid_to_outputs:
                    # This inode produced no outputs; stop descending. The code
                    # below assumes every inode on the path has outputs, so
                    # `node` should never be None after this loop.
                    node = None
                    break
outputs = wnid_to_outputs[node.wnid]
index_child = outputs["preds"][index]
prob_child = float(outputs["probs"][index][index_child])
node = node.children[index_child]
decision.append(
{
"node": node,
"name": node.name,
"prob": prob_child,
"next_index": index_child,
"entropy": float(outputs["entropy"][index]),
}
)
preds.append(tree.wnid_to_class_index[node.wnid])
decisions.append(decision)
return torch.Tensor(preds).long().to(device), decisions
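    # Each entry of `decisions` is a root-to-leaf path, e.g. (values
    # hypothetical):
    #   [{"node": <root>, "name": "root", "prob": 1, "entropy": 0},
    #    {"node": <child>, "name": "animal", "prob": 0.9, "next_index": 1,
    #     "entropy": 0.3}, ...]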
def predicted_to_logits(self, predicted):
"""Convert predicted classes to one-hot logits."""
if self.I.device != predicted.device:
self.I = self.I.to(predicted.device)
return self.I[predicted]
def forward_with_decisions(self, outputs):
wnid_to_outputs = self.forward_nodes(outputs)
predicted, decisions = self.traverse_tree(wnid_to_outputs, self.tree)
logits = self.predicted_to_logits(predicted)
logits._nbdt_output_flag = True # checked in nbdt losses, to prevent mistakes
return logits, decisions
def forward(self, outputs):
outputs, _ = self.forward_with_decisions(outputs)
return outputs


class SoftEmbeddedDecisionRules(EmbeddedDecisionRules):
@classmethod
def traverse_tree(cls, wnid_to_outputs, tree):
"""
In theory, the loop over children below could be replaced with just a
few lines:
for index_child in range(len(node.children)):
old_indexes = node.child_index_to_class_index[index_child]
class_probs[:,old_indexes] *= output[:,index_child][:,None]
However, we collect all indices first, so that only one tensor operation
is run. The output is a single distribution over all leaves. The
ordering is determined by the original ordering of the provided logits.
(I think. Need to check nbdt.data.custom.Node)
"""
example = wnid_to_outputs[tree.inodes[0].wnid]
num_samples = example["logits"].size(0)
num_classes = len(tree.classes)
device = example["logits"].device
        class_probs = torch.ones(num_samples, num_classes, device=device)
for node in tree.inodes:
outputs = wnid_to_outputs[node.wnid]
old_indices, new_indices = [], []
for index_child in range(len(node.children)):
old = node.child_index_to_class_index[index_child]
old_indices.extend(old)
new_indices.extend([index_child] * len(old))
assert len(set(old_indices)) == len(old_indices), (
"All old indices must be unique in order for this operation "
"to be correct."
)
class_probs[:, old_indices] *= outputs["probs"][:, new_indices]
return class_probs
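    # Worked example of the product above (values hypothetical): a leaf
    # reached via two internal decisions with child probabilities 0.9 and
    # 0.8 ends up with class_probs[:, leaf] = 1 * 0.9 * 0.8 = 0.72, i.e. the
    # path probability of that leaf.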
def forward_with_decisions(self, outputs):
wnid_to_outputs = self.forward_nodes(outputs)
outputs = self.forward(outputs, wnid_to_outputs)
_, predicted = outputs.max(1)
decisions = []
leaf_to_steps = self.tree.get_leaf_to_steps()
for index, prediction in enumerate(predicted):
leaf = self.tree.wnids_leaves[prediction]
steps = leaf_to_steps[leaf]
probs = [1]
entropies = [0]
            for step in steps[:-1]:
                _out = wnid_to_outputs[step["node"].wnid]
                # Use this sample's probabilities (not sample 0's), so each
                # decision path reports its own confidences.
                _probs = _out["probs"][index]
                probs.append(_probs[step["next_index"]])
                entropies.append(Categorical(probs=_probs).entropy().item())
for step, prob, entropy in zip(steps, probs, entropies):
step["prob"] = float(prob)
step["entropy"] = float(entropy)
decisions.append(steps)
return outputs, decisions
def forward(self, outputs, wnid_to_outputs=None):
if not wnid_to_outputs:
wnid_to_outputs = self.forward_nodes(outputs)
logits = self.traverse_tree(wnid_to_outputs, self.tree)
logits._nbdt_output_flag = True # checked in nbdt losses, to prevent mistakes
return logits
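

# The rules can also be used standalone on precomputed logits. A minimal
# sketch (hypothetical `net` and batch `x`; assumes the Tree can infer the
# default hierarchy files for the dataset):
#
#   rules = SoftEmbeddedDecisionRules(dataset="CIFAR10")
#   leaf_probs = rules(net(x))  # (N, num_classes) distribution over leaves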


##########
# MODELS #
##########
class NBDT(nn.Module):
def __init__(
self,
dataset,
model,
arch=None,
path_graph=None,
path_wnids=None,
classes=None,
hierarchy=None,
pretrained=None,
**kwargs,
):
super().__init__()
if dataset and not hierarchy and not path_graph:
assert arch, "Must specify `arch` if no `hierarchy` or `path_graph`"
hierarchy = f"induced-{arch}"
if pretrained and not arch:
raise UserWarning(
"To load a pretrained NBDT, you need to specify the `arch`. "
"`arch` is the name of the architecture. e.g., ResNet18"
)
if isinstance(model, str):
raise NotImplementedError("Model must be nn.Module")
tree = Tree(dataset, path_graph, path_wnids, classes, hierarchy=hierarchy)
self.init(
dataset,
model,
tree,
arch=arch,
pretrained=pretrained,
hierarchy=hierarchy,
**kwargs,
)
def init(
self,
dataset,
model,
tree,
arch=None,
pretrained=False,
hierarchy=None,
eval=True,
Rules=HardEmbeddedDecisionRules,
):
"""
Extra init method makes clear which arguments are finally necessary for
this class to function. The constructor for this class may generate
some of these required arguments if initially missing.
"""
self.rules = Rules(tree=tree)
self.model = model
if pretrained:
assert arch is not None
keys = [(arch, dataset), (arch, dataset, hierarchy)]
state_dict = load_state_dict_from_key(keys, model_urls, pretrained=True)
self.load_state_dict(state_dict)
if eval:
self.eval()
def load_state_dict(self, state_dict, **kwargs):
state_dict = coerce_state_dict(state_dict, self.model.state_dict())
return self.model.load_state_dict(state_dict, **kwargs)
def state_dict(self, *args, **kwargs):
return self.model.state_dict(*args, **kwargs)
def forward(self, x):
x = self.model(x)
x = self.rules(x)
return x
def forward_with_decisions(self, x):
x = self.model(x)
x, decisions = self.rules.forward_with_decisions(x)
return x, decisions


class HardNBDT(NBDT):
def __init__(self, *args, **kwargs):
kwargs.update({"Rules": HardEmbeddedDecisionRules})
super().__init__(*args, **kwargs)


class SoftNBDT(NBDT):
def __init__(self, *args, **kwargs):
kwargs.update({"Rules": SoftEmbeddedDecisionRules})
super().__init__(*args, **kwargs)
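

# Minimal usage sketch for the classification wrappers (ResNet18 from
# nbdt.models is assumed here; any nn.Module returning logits works):
#
#   from nbdt.models import ResNet18
#   net = ResNet18()
#   model = SoftNBDT(dataset="CIFAR10", model=net, arch="ResNet18",
#                    pretrained=True)  # `arch` is required when pretrained
#   outputs = model(images)  # (N, num_classes) distribution over leaves
#   outputs, decisions = model.forward_with_decisions(images)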


class SegNBDT(NBDT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
    def forward(self, x):
        assert len(x.shape) == 4, "Input must be of shape (N,C,H,W) for segmentation"
        x = self.model(x)
        original_shape = x.shape
        x = coerce_tensor(x)  # flatten spatial dims: rules see one row per pixel
        x = self.rules.forward(x)
        x = uncoerce_tensor(x, original_shape)  # restore (N,C,H,W)
        return x


class HardSegNBDT(SegNBDT):
def __init__(self, *args, **kwargs):
kwargs.update({"Rules": HardEmbeddedDecisionRules})
super().__init__(*args, **kwargs)


class SoftSegNBDT(SegNBDT):
def __init__(self, *args, **kwargs):
kwargs.update({"Rules": SoftEmbeddedDecisionRules})
super().__init__(*args, **kwargs)
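

# For segmentation, the same pattern applies per pixel (hypothetical
# `seg_net` and dataset; no pretrained segmentation checkpoints are listed
# in model_urls above):
#
#   model = SoftSegNBDT(dataset="Cityscapes", model=seg_net, arch="seg_arch")
#   # (`arch` names the induced hierarchy: hierarchy = f"induced-{arch}")
#   outputs = model(images)  # (N, num_classes, H, W), rules applied per pixel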