from collections import OrderedDict

import logging
import numpy as np
from PIL import Image

from model_tools.activations.core import ActivationsExtractorHelper
from model_tools.utils import fullname

SUBMODULE_SEPARATOR = '.'
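# Layer names passed to get_layer/get_activations address submodules by this
# separator, e.g. 'features.0' for model.features[0] in a torchvision model.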


class PytorchWrapper:
    def __init__(self, model, preprocessing, identifier=None, forward_kwargs=None, *args, **kwargs):
        import torch
        logger = logging.getLogger(fullname(self))
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.debug(f"Using device {self._device}")
        self._model = model.to(self._device)
        identifier = identifier or model.__class__.__name__
        self._extractor = self._build_extractor(
            identifier=identifier, preprocessing=preprocessing, get_activations=self.get_activations, *args, **kwargs)
        self._extractor.insert_attrs(self)
        self._forward_kwargs = forward_kwargs or {}

    def _build_extractor(self, identifier, preprocessing, get_activations, *args, **kwargs):
        return ActivationsExtractorHelper(
            identifier=identifier, get_activations=get_activations, preprocessing=preprocessing,
            *args, **kwargs)

    @property
    def identifier(self):
        return self._extractor.identifier

    @identifier.setter
    def identifier(self, value):
        self._extractor.identifier = value

    def __call__(self, *args, **kwargs):  # cannot assign __call__ as attribute due to Python convention
        return self._extractor(*args, **kwargs)

    def get_activations(self, images, layer_names):
        import torch
        images = [torch.from_numpy(image) if not isinstance(image, torch.Tensor) else image for image in images]
        images = torch.stack(images).to(self._device)
        self._model.eval()
        # capture each requested layer's output via forward hooks, then run a single forward pass
        layer_results = OrderedDict()
        hooks = []
        for layer_name in layer_names:
            layer = self.get_layer(layer_name)
            hook = self.register_hook(layer, layer_name, target_dict=layer_results)
            hooks.append(hook)
        with torch.no_grad():
            self._model(images, **self._forward_kwargs)
        for hook in hooks:
            hook.remove()
        return layer_results

    def get_layer(self, layer_name):
        if layer_name == 'logits':
            return self._output_layer()
        module = self._model
        for part in layer_name.split(SUBMODULE_SEPARATOR):
            module = module._modules.get(part)
            assert module is not None, f"No submodule found for layer {layer_name}, at part {part}"
        return module

    def _output_layer(self):
        module = self._model
        while module._modules:  # descend into the last child until we reach a leaf module
            module = module._modules[next(reversed(module._modules))]
        return module

    @classmethod
    def _tensor_to_numpy(cls, output):
        return output.cpu().data.numpy()

    def register_hook(self, layer, layer_name, target_dict):
        def hook_function(_layer, _input, output, name=layer_name):
            target_dict[name] = PytorchWrapper._tensor_to_numpy(output)

        hook = layer.register_forward_hook(hook_function)
        return hook

    def __repr__(self):
        return repr(self._model)

    def layers(self):
        for name, module in self._model.named_modules():
            if len(list(module.children())) > 0:  # this module only holds other modules
                continue
            yield name, module

    def graph(self):
        import networkx as nx
        g = nx.DiGraph()
        for layer_name, layer in self.layers():
            g.add_node(layer_name, object=layer, type=type(layer))
        return g
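

# Example sketch (not part of the original module; assumes torchvision is
# installed and 'example.png' is a hypothetical image path): wrap a torchvision
# model and read activations out of named layers. Layer names follow the
# dotted-path convention used by get_layer, e.g. 'features.0'; the special
# name 'logits' resolves to the model's final module.
def _example_wrap_alexnet():
    import functools
    from torchvision.models import alexnet
    preprocessing = functools.partial(load_preprocess_images, image_size=224)
    wrapper = PytorchWrapper(model=alexnet(),  # randomly initialized weights suffice for a shape demo
                             preprocessing=preprocessing, identifier='alexnet-example')
    images = load_preprocess_images(['example.png'], image_size=224)  # hypothetical file path
    activations = wrapper.get_activations(images, layer_names=['features.0', 'logits'])
    for name, array in activations.items():
        print(name, array.shape)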


def load_preprocess_images(image_filepaths, image_size, **kwargs):
    images = load_images(image_filepaths)
    images = preprocess_images(images, image_size=image_size, **kwargs)
    return images


def load_images(image_filepaths):
    return [load_image(image_filepath) for image_filepath in image_filepaths]


def load_image(image_filepath):
    with Image.open(image_filepath) as pil_image:
        if 'L' not in pil_image.mode.upper() and 'A' not in pil_image.mode.upper() \
                and 'P' not in pil_image.mode.upper():  # not binary, not alpha, not palettized
            # workaround for https://github.com/python-pillow/Pillow/issues/1144,
            # see https://stackoverflow.com/a/30376272/2225200
            return pil_image.copy()
        else:  # make sure potential binary images are in RGB
            rgb_image = Image.new("RGB", pil_image.size)
            rgb_image.paste(pil_image)
            return rgb_image


def preprocess_images(images, image_size, **kwargs):
    preprocess = torchvision_preprocess_input(image_size, **kwargs)
    images = [preprocess(image) for image in images]
    images = np.concatenate(images)
    return images


def torchvision_preprocess_input(image_size, **kwargs):
    from torchvision import transforms
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        torchvision_preprocess(**kwargs),
    ])


def torchvision_preprocess(normalize_mean=(0.485, 0.456, 0.406), normalize_std=(0.229, 0.224, 0.225)):
    from torchvision import transforms
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_mean, std=normalize_std),
        lambda img: img.unsqueeze(0),  # add a leading batch dimension so preprocessed images can be concatenated
    ])
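

# Example sketch (hypothetical file paths): the preprocessing pipeline maps a
# list of image files to a single float array of shape
# (N, 3, image_size, image_size), normalized with the ImageNet statistics above.
def _example_preprocess():
    batch = load_preprocess_images(['cat.png', 'dog.png'], image_size=224)  # hypothetical files
    print(batch.shape)  # expected: (2, 3, 224, 224)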