import numpy as np
from .abc_interpreter import Interpreter
from ..data_processor.readers import images_transform_pipeline, preprocess_save_path
from ..data_processor.visualizer import explanation_to_vis, show_vis_explanation, save_image


class LRPCVInterpreter(Interpreter):
    """
    Layer-wise Relevance Propagation (LRP) Interpreter for CV tasks.

    A detailed introduction to LRP can be found in the tutorial. Layer-wise Relevance Propagation (LRP) is an
    explanation technique applicable to models structured as neural networks, where the inputs can be, e.g., images,
    videos, or text. LRP operates by propagating the prediction backwards through the neural network, by means of
    purpose-designed local propagation rules.

    Note that LRP requires ``model`` to have :py:func:`relprop` and related implementations, see
    `tutorial/assets/lrp_model <https://github.com/PaddlePaddle/InterpretDL/tree/master/tutorials/assets/lrp_model>`_.
    This is different from other interpreters, which have no additional requirements for ``model``.

    More details regarding the LRP method can be found in the original paper:
    https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140.
    """

def __init__(self, model: callable, device: str = 'gpu:0') -> None:
"""
Args:
model (callable): A model with :py:func:`forward` and possibly :py:func:`backward` functions.
device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
etc.
"""
Interpreter.__init__(self, model, device)
self.paddle_prepared = False
def interpret(self, inputs, label=None, resize_to=224, crop_to=None, visual=True, save_path=None):
"""
The difficulty for LRP implementation does not reside the algorithm, but the model. The model should be
implemented with :py:func:`relprop` functions, and the algorithm calls the relevance back-propagation, until
the input layer, where the final explanation is obtained.
Args:
inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy
array of read images.
labels (list or tuple or numpy.ndarray, optional): The target labels to analyze. The number of labels
should be equal to the number of images. If None, the most likely label for each image will be used.
Default: ``None``.
resize_to (int, optional): Images will be rescaled with the shorter edge being ``resize_to``. Defaults to
``224``.
crop_to (int, optional): After resize, images will be center cropped to a square image with the size
``crop_to``. If None, no crop will be performed. Defaults to ``None``.
visual (bool, optional): Whether or not to visualize the processed image. Default: ``True``.
save_path (str, optional): The filepath(s) to save the processed image(s). If None, the image will not be
saved. Default: ``None``.
Returns:
[numpy.ndarray]: interpretations/Relevance map for images.
"""
imgs, data = images_transform_pipeline(inputs, resize_to, crop_to)
bsz = len(data)
save_path = preprocess_save_path(save_path, bsz)
if not self.paddle_prepared:
self._paddle_prepare()
R, output = self.predict_fn(data, label)
# visualization and save image.
for i in range(len(imgs)):
vis_explanation = explanation_to_vis(imgs[i], R[i].squeeze(), style='overlay_grayscale')
if visual:
show_vis_explanation(vis_explanation)
if save_path[i] is not None:
save_image(save_path[i], vis_explanation)
return R
def _paddle_prepare(self, predict_fn=None):
if predict_fn is None:
import paddle
paddle.set_device(self.device)
self.model.eval()
layer_list = [(n, v) for n, v in self.model.named_sublayers()]
num_classes = layer_list[-1][1].weight.shape[1]
            def predict_fn(data, label):
                data = paddle.to_tensor(data, stop_gradient=False)
                output = self.model(data)
                if label is None:
                    # use the predicted class of each image as the target.
                    T = output.numpy().argmax(axis=-1)
                else:
                    assert isinstance(label, int), "label should be an integer"
                    assert 0 <= label < num_classes, f"label should be in [0, {num_classes})"
                    T = np.array([label])
                # one-hot encode the target class(es); the result broadcasts over the batch.
                T = (T[:, np.newaxis] == np.arange(num_classes)) * 1.0
                T = paddle.to_tensor(T).astype('float32')
                # back-propagate the relevance of the target logit(s) to the input layer and
                # aggregate over the channel axis to obtain one relevance map per image.
                R = self.model.relprop(R=output * T, alpha=1).sum(axis=1, keepdim=True)
                # Check that the relevance value is preserved:
                # print('Pred logit : ' + str((output * T).sum().cpu().numpy()))
                # print('Relevance Sum : ' + str(R.sum().cpu().numpy()))
                return R.detach().numpy(), output.detach().numpy()
self.predict_fn = predict_fn
self.paddle_prepared = True
return predict_fn
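

if __name__ == '__main__':
    # -----------------------------------------------------------------------
    # A minimal, self-contained sketch (illustration only) of the ``relprop``
    # contract that LRPCVInterpreter expects from ``model``: each layer takes
    # the relevance R assigned to its output and redistributes it to its
    # input, and the model chains these calls in reverse order. The class
    # names below and the LRP-alpha1-beta0 (z+) rule are simplifications for
    # illustration; the actual LRP-ready layers used with this interpreter
    # live in tutorials/assets/lrp_model.
    # -----------------------------------------------------------------------
    import paddle

    class LinearLRP(paddle.nn.Linear):
        def forward(self, x):
            self.X = x  # cache the input activations for relprop
            return super().forward(x)

        def relprop(self, R, alpha=1):
            # z+ rule: redistribute R proportionally to positive contributions.
            w_pos = paddle.clip(self.weight, min=0.)
            z = paddle.matmul(self.X, w_pos) + 1e-9
            s = R / z
            c = paddle.matmul(s, w_pos, transpose_y=True)
            return self.X * c

    class TinyLRPModel(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc1 = LinearLRP(8, 16)
            self.fc2 = LinearLRP(16, 3)

        def forward(self, x):
            return self.fc2(self.fc1(x))

        def relprop(self, R, alpha=1):
            # propagate the relevance backwards, layer by layer.
            R = self.fc2.relprop(R, alpha)
            return self.fc1.relprop(R, alpha)

    demo_model = TinyLRPModel()
    x = paddle.rand([1, 8])
    out = demo_model(x)
    onehot = paddle.nn.functional.one_hot(out.argmax(axis=-1), num_classes=3).astype('float32')
    R_in = demo_model.relprop(R=out * onehot, alpha=1)
    # the relevance sum is (approximately) preserved: sum(R_in) ~= selected logit
    print(float(R_in.sum()), float((out * onehot).sum()))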
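

# ---------------------------------------------------------------------------
# Hedged usage sketch of the interpreter itself (illustration only). The model
# constructor ``build_lrp_ready_model`` and the image path ``assets/catdog.png``
# are placeholders, not real names from this repository; any model passed to
# LRPCVInterpreter must implement ``relprop`` end to end (see
# tutorials/assets/lrp_model).
#
#     model = build_lrp_ready_model()                      # placeholder
#     lrp = LRPCVInterpreter(model, device='gpu:0')
#     R = lrp.interpret('assets/catdog.png',               # placeholder image
#                       label=None,                        # use predicted class
#                       resize_to=224, crop_to=224,
#                       visual=True, save_path=None)
#     # R has shape (1, 1, 224, 224): one relevance map per input image.
# ---------------------------------------------------------------------------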