-
-
Notifications
You must be signed in to change notification settings - Fork 211
/
gradient.py
407 lines (326 loc) · 16.5 KB
/
gradient.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
# Copyright (C) 2020-2023, François-Guillaume Fernandez.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
from functools import partial
from typing import Any, List, Optional, Tuple, Union, cast
import torch
from torch import Tensor, nn
from .core import _CAM
__all__ = ["GradCAM", "GradCAMpp", "SmoothGradCAMpp", "XGradCAM", "LayerCAM"]
class _GradCAM(_CAM):
    """Base class for class activation map extractors that need the gradient of the class score.

    On top of what the parent class sets up, a forward hook is registered on every target layer. Each time
    that layer runs, the hook attaches a tensor hook to its output so the gradient flowing back through it
    is captured in ``hook_g``.

    Args:
        model: input model
        target_layer: either the target layer itself or its name, or a list of those
        input_shape: shape of the expected input tensor excluding the batch dimension
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
        input_shape: Tuple[int, ...] = (3, 224, 224),
        **kwargs: Any,
    ) -> None:
        super().__init__(model, target_layer, input_shape, **kwargs)
        # ReLU is applied to the map ahead of normalization
        self._relu = True
        # The extractor backpropagates from the model output, so it needs the scores
        self._score_used = True
        for layer_idx, layer_name in enumerate(self.target_names):
            # Hooking the output tensor (instead of using a module backward hook) sidesteps problems with
            # inplace operations, cf. https://github.com/pytorch/pytorch/issues/61519
            grad_hook = partial(self._hook_g, idx=layer_idx)
            handle = self.submodule_dict[layer_name].register_forward_hook(grad_hook)
            self.hook_handles.append(handle)

    def _store_grad(self, grad: Tensor, idx: int = 0) -> None:
        """Tensor hook: save the incoming gradient into slot ``idx`` of ``hook_g`` while hooks are enabled."""
        if not self._hooks_enabled:
            return
        self.hook_g[idx] = grad.data

    def _hook_g(self, _: nn.Module, _input: Tuple[Tensor, ...], output: Tensor, idx: int = 0) -> None:
        """Gradient hook.

        Forward hook of target layer ``idx``. It attaches ``_store_grad`` to the layer output and keeps the
        returned handle in ``hook_handles``.
        """
        if not self._hooks_enabled:
            return
        handle = output.register_hook(partial(self._store_grad, idx=idx))
        self.hook_handles.append(handle)

    def _backprop(
        self,
        scores: Tensor,
        class_idx: Union[int, List[int]],
        retain_graph: bool = False,
    ) -> None:
        """Backpropagate the loss for a specific output class.

        Args:
            scores: model output scores, with classes indexed along dimension 1
            class_idx: a single class index used for every sample, or one class index per sample
            retain_graph: whether to keep the autograd graph alive after the backward pass
        """
        if isinstance(class_idx, int):
            # Same class for the whole batch
            selected = scores[:, class_idx]
        else:
            # One class per sample: pick scores[n, class_idx[n]]
            index = torch.tensor(class_idx, device=scores.device).view(-1, 1)
            selected = scores.gather(1, index)
        loss = selected.sum()
        # Clear stale parameter gradients, then fill the hooked gradients
        self.model.zero_grad()
        loss.backward(retain_graph=retain_graph)
class GradCAM(_GradCAM):
    r"""Implements a class activation map extractor as described in `"Grad-CAM: Visual Explanations from Deep Networks
    via Gradient-based Localization" <https://arxiv.org/pdf/1610.02391.pdf>`_.

    The localization map is computed as follows:

    .. math::
        L^{(c)}_{Grad-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)

    where each channel coefficient :math:`w_k^{(c)}` is the spatial average of the gradient:

    .. math::
        w_k^{(c)} = \frac{1}{H \cdot W} \sum\limits_{i=1}^H \sum\limits_{j=1}^W
        \frac{\partial Y^{(c)}}{\partial A_k(i, j)}

    Here :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at
    position :math:`(x, y)`, and :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax.

    >>> from torchvision.models import resnet18
    >>> from torchcam.methods import GradCAM
    >>> model = resnet18(pretrained=True).eval()
    >>> cam = GradCAM(model, 'layer4')
    >>> scores = model(input_tensor)
    >>> cam(class_idx=100, scores=scores)

    Args:
        model: input model
        target_layer: either the target layer itself or its name, or a list of those
        input_shape: shape of the expected input tensor excluding the batch dimension
    """

    def _get_weights(self, class_idx: Union[int, List[int]], scores: Tensor, **kwargs: Any) -> List[Tensor]:
        """Return, for every target layer, the spatially averaged gradient of the class score.

        Args:
            class_idx: a single class index for the whole batch, or one class index per sample
            scores: model output scores to backpropagate from
            kwargs: forwarded to ``_backprop`` (e.g. ``retain_graph``)

        Returns:
            one tensor per target layer in which every dimension after the second has been averaged out
        """
        self._backprop(scores, class_idx, **kwargs)
        self.hook_g: List[Tensor]  # type: ignore[assignment]
        # Global average pooling: merge the spatial dims into one, then take its mean
        return [layer_grad.flatten(start_dim=2).mean(dim=-1) for layer_grad in self.hook_g]
class GradCAMpp(_GradCAM):
    r"""Implements a class activation map extractor as described in `"Grad-CAM++: Improved Visual Explanations for
    Deep Convolutional Networks" <https://arxiv.org/pdf/1710.11063.pdf>`_.

    The localization map is computed as follows:

    .. math::
        L^{(c)}_{Grad-CAM++}(x, y) = \sum\limits_k w_k^{(c)} A_k(x, y)

    with the coefficient :math:`w_k^{(c)}` being defined as:

    .. math::
        w_k^{(c)} = \sum\limits_{i=1}^H \sum\limits_{j=1}^W \alpha_k^{(c)}(i, j) \cdot
        ReLU\Big(\frac{\partial Y^{(c)}}{\partial A_k(i, j)}\Big)

    where :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at
    position :math:`(x, y)`,
    :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax,
    and :math:`\alpha_k^{(c)}(i, j)` being defined as:

    .. math::
        \alpha_k^{(c)}(i, j) = \frac{1}{\sum\limits_{i, j} \frac{\partial Y^{(c)}}{\partial A_k(i, j)}}
        = \frac{\frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2}}{2 \cdot
        \frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2} + \sum\limits_{a,b} A_k (a,b) \cdot
        \frac{\partial^3 Y^{(c)}}{(\partial A_k(i,j))^3}}

    if :math:`\frac{\partial Y^{(c)}}{\partial A_k(i, j)} \neq 0` else :math:`0`.

    >>> from torchvision.models import resnet18
    >>> from torchcam.methods import GradCAMpp
    >>> model = resnet18(pretrained=True).eval()
    >>> cam = GradCAMpp(model, 'layer4')
    >>> scores = model(input_tensor)
    >>> cam(class_idx=100, scores=scores)

    Args:
        model: input model
        target_layer: either the target layer itself or its name, or a list of those
        input_shape: shape of the expected input tensor excluding the batch dimension
    """

    def _get_weights(
        self,
        class_idx: Union[int, List[int]],
        scores: Tensor,
        eps: float = 1e-8,
        **kwargs: Any,
    ) -> List[Tensor]:
        """Computes the weight coefficients of the hooked activation maps.

        Args:
            class_idx: a single class index for the whole batch, or one class index per sample
            scores: model output scores to backpropagate from
            eps: small constant added to the alpha denominator for numerical stability
            kwargs: forwarded to ``_backprop`` (e.g. ``retain_graph``)

        Returns:
            one weight tensor per target layer, with the spatial dimensions summed out
        """
        # Backpropagate
        self._backprop(scores, class_idx, **kwargs)
        self.hook_a: List[Tensor]  # type: ignore[assignment]
        self.hook_g: List[Tensor]  # type: ignore[assignment]
        # The second- and third-order terms of the alpha formula are computed as powers of the gradient
        grad_2 = [grad.pow(2) for grad in self.hook_g]
        grad_3 = [g2 * grad for g2, grad in zip(grad_2, self.hook_g)]
        # Denominator: 2 * grad^2 + (sum over spatial positions of grad^3 * activation), broadcast back
        # over the spatial dimensions
        spatial_dims = self.hook_a[0].ndim - 2
        denom = [
            2 * g2 + (g3 * act).flatten(2).sum(-1)[(...,) + (None,) * spatial_dims]
            for g2, g3, act in zip(grad_2, grad_3, self.hook_a)
        ]
        # Watch out for NaNs produced by underflow: only divide where the gradient is non-zero and keep
        # grad^2 (i.e. 0) elsewhere. This must be out-of-place: boolean-mask indexing (`t[mask]`) returns a
        # copy, so an in-place `t[mask].div_(...)` would never write back into `t`.
        alpha = [torch.where(g2 > 0, g2 / (d + eps), g2) for g2, d in zip(grad_2, denom)]
        # Apply pixel coefficient in each weight
        return [a.mul_(torch.relu(grad)).flatten(2).sum(-1) for a, grad in zip(alpha, self.hook_g)]
class SmoothGradCAMpp(_GradCAM):
    r"""Implements a class activation map extractor as described in `"Smooth Grad-CAM++: An Enhanced Inference Level
    Visualization Technique for Deep Convolutional Neural Network Models" <https://arxiv.org/pdf/1908.01224.pdf>`_
    with a personal correction to the paper (alpha coefficient numerator).

    The localization map is computed as follows:

    .. math::
        L^{(c)}_{Smooth Grad-CAM++}(x, y) = \sum\limits_k w_k^{(c)} A_k(x, y)

    with the coefficient :math:`w_k^{(c)}` being defined as:

    .. math::
        w_k^{(c)} = \sum\limits_{i=1}^H \sum\limits_{j=1}^W \alpha_k^{(c)}(i, j) \cdot
        ReLU\Big(\frac{\partial Y^{(c)}}{\partial A_k(i, j)}\Big)

    where :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at
    position :math:`(x, y)`,
    :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax,
    and :math:`\alpha_k^{(c)}(i, j)` being defined as:

    .. math::
        \alpha_k^{(c)}(i, j)
        = \frac{\frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2}}{2 \cdot
        \frac{\partial^2 Y^{(c)}}{(\partial A_k(i,j))^2} + \sum\limits_{a,b} A_k (a,b) \cdot
        \frac{\partial^3 Y^{(c)}}{(\partial A_k(i,j))^3}}
        = \frac{\frac{1}{n} \sum\limits_{m=1}^n D^{(c, 2)}_k(i, j)}{
        \frac{2}{n} \sum\limits_{m=1}^n D^{(c, 2)}_k(i, j) + \sum\limits_{a,b} A_k (a,b) \cdot
        \frac{1}{n} \sum\limits_{m=1}^n D^{(c, 3)}_k(i, j)}

    if :math:`\frac{\partial Y^{(c)}}{\partial A_k(i, j)} \neq 0` else :math:`0`. Here :math:`D^{(c, p)}_k(i, j)`
    refers to the p-th partial derivative of the class score of class :math:`c` relatively to the activation in layer
    :math:`k` at position :math:`(i, j)`, and :math:`n` is the number of samples used to get the gradient estimate.

    Please note the difference in the numerator of :math:`\alpha_k^{(c)}(i, j)`,
    which is actually :math:`\frac{1}{n} \sum\limits_{m=1}^n D^{(c, 1)}_k(i,j)` in the paper.

    >>> from torchvision.models import resnet18
    >>> from torchcam.methods import SmoothGradCAMpp
    >>> model = resnet18(pretrained=True).eval()
    >>> cam = SmoothGradCAMpp(model, 'layer4')
    >>> scores = model(input_tensor)
    >>> cam(class_idx=100)

    Args:
        model: input model
        target_layer: either the target layer itself or its name, or a list of those
        num_samples: number of samples to use for smoothing
        std: standard deviation of the noise
        input_shape: shape of the expected input tensor excluding the batch dimension
    """

    def __init__(
        self,
        model: nn.Module,
        target_layer: Optional[Union[Union[nn.Module, str], List[Union[nn.Module, str]]]] = None,
        num_samples: int = 4,
        std: float = 0.3,
        input_shape: Tuple[int, ...] = (3, 224, 224),
        **kwargs: Any,
    ) -> None:
        super().__init__(model, target_layer, input_shape, **kwargs)
        # Model scores is not used by the extractor: it runs its own noisy forward passes
        self._score_used = False
        # Input hook: remembers the latest model input so that noise can be added to it later
        self.hook_handles.append(model.register_forward_pre_hook(self._store_input))  # type: ignore[arg-type]
        # Noise distribution
        self.num_samples = num_samples
        self.std = std
        self._distrib = torch.distributions.normal.Normal(0, self.std)
        # Specific input hook updater (switched off while the noisy samples are being forwarded)
        self._ihook_enabled = True

    def _store_input(self, _: nn.Module, _input: Tensor) -> None:
        """Store model input tensor.

        Forward pre-hook: saves a detached copy of ``_input[0]`` (the first positional argument of the forward
        call) while ``_ihook_enabled`` is set.
        """
        if self._ihook_enabled:
            self._input = _input[0].data.clone()

    def _get_weights(
        self,
        class_idx: Union[int, List[int]],
        _: Union[Tensor, None] = None,
        eps: float = 1e-8,
        **kwargs: Any,
    ) -> List[Tensor]:
        """Computes the weight coefficients of the hooked activation maps.

        Gradient statistics are averaged over ``num_samples`` forward/backward passes on noisy copies of the
        stored input.

        Args:
            class_idx: a single class index for the whole batch, or one class index per sample
            _: unused (model scores are not needed by this extractor)
            eps: small constant added to the alpha denominator for numerical stability
            kwargs: forwarded to ``_backprop`` (e.g. ``retain_graph``)

        Returns:
            one weight tensor per target layer, with the spatial dimensions summed out
        """
        self.hook_a: List[Tensor]  # type: ignore[assignment]
        self.hook_g: List[Tensor]  # type: ignore[assignment]
        # Keep the activations captured before the noisy forward passes are run
        init_fmap = [act.clone() for act in self.hook_a]
        # Initialize our gradient estimates
        grad_2 = [torch.zeros_like(act) for act in self.hook_a]
        grad_3 = [torch.zeros_like(act) for act in self.hook_a]
        # Match the noise dtype to floating-point inputs (e.g. half precision) so the addition does not
        # promote the input to float32; non-floating inputs keep the original promotion behaviour
        noise_dtype = self._input.dtype if self._input.is_floating_point() else None
        # Disable input update so the noisy inputs do not replace the stored one
        self._ihook_enabled = False
        try:
            # Perform the operations N times
            for _idx in range(self.num_samples):
                # Add noise
                noise = self._distrib.sample(self._input.size()).to(device=self._input.device, dtype=noise_dtype)
                noisy_input = self._input + noise
                noisy_input.requires_grad_(True)
                # Forward & Backward
                out = self.model(noisy_input)
                self.model.zero_grad()
                self._backprop(out, class_idx, **kwargs)
                # Sum partial derivatives
                for g2, g3, grad in zip(grad_2, grad_3, self.hook_g):
                    g2.add_(grad.pow(2))
                    g3.add_(grad.pow(3))
        finally:
            # Re-enable input update even if a forward/backward pass raised, otherwise the stored input
            # would stay frozen for every later call
            self._ihook_enabled = True
        # Average the gradient estimates
        grad_2 = [g2.div_(self.num_samples) for g2 in grad_2]
        grad_3 = [g3.div_(self.num_samples) for g3 in grad_3]
        # Alpha coefficient for each pixel, computed with the activations captured before the noisy passes
        spatial_dims = self.hook_a[0].ndim - 2
        alpha = [
            g2 / (2 * g2 + (g3 * act).flatten(2).sum(-1)[(...,) + (None,) * spatial_dims] + eps)
            for g2, g3, act in zip(grad_2, grad_3, init_fmap)
        ]
        # Apply pixel coefficient in each weight.
        # NOTE(review): `hook_g` here holds the gradients of the last noisy sample, not an average.
        return [a.mul_(torch.relu(grad)).flatten(2).sum(-1) for a, grad in zip(alpha, self.hook_g)]

    def extra_repr(self) -> str:
        return f"target_layer={self.target_names}, num_samples={self.num_samples}, std={self.std}"
class XGradCAM(_GradCAM):
    r"""Implements a class activation map extractor as described in `"Axiom-based Grad-CAM: Towards Accurate
    Visualization and Explanation of CNNs" <https://arxiv.org/pdf/2008.02312.pdf>`_.

    The localization map is computed as follows:

    .. math::
        L^{(c)}_{XGrad-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)} A_k(x, y)\Big)

    where each channel coefficient :math:`w_k^{(c)}` is the activation-weighted sum of the gradient:

    .. math::
        w_k^{(c)} = \sum\limits_{i=1}^H \sum\limits_{j=1}^W
        \Big( \frac{\partial Y^{(c)}}{\partial A_k(i, j)} \cdot
        \frac{A_k(i, j)}{\sum\limits_{m=1}^H \sum\limits_{n=1}^W A_k(m, n)} \Big)

    Here :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at
    position :math:`(x, y)`, and :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax.

    >>> from torchvision.models import resnet18
    >>> from torchcam.methods import XGradCAM
    >>> model = resnet18(pretrained=True).eval()
    >>> cam = XGradCAM(model, 'layer4')
    >>> scores = model(input_tensor)
    >>> cam(class_idx=100, scores=scores)

    Args:
        model: input model
        target_layer: either the target layer itself or its name, or a list of those
        input_shape: shape of the expected input tensor excluding the batch dimension
    """

    def _get_weights(
        self,
        class_idx: Union[int, List[int]],
        scores: Tensor,
        eps: float = 1e-8,
        **kwargs: Any,
    ) -> List[Tensor]:
        """Return, for every target layer, sum(grad * act) / (sum(act) + eps) over spatial positions.

        Args:
            class_idx: a single class index for the whole batch, or one class index per sample
            scores: model output scores to backpropagate from
            eps: small constant added to the activation sum to avoid dividing by zero
            kwargs: forwarded to ``_backprop`` (e.g. ``retain_graph``)

        Returns:
            one weight tensor per target layer, with the spatial dimensions reduced
        """
        self._backprop(scores, class_idx, **kwargs)
        self.hook_a: List[Tensor]  # type: ignore[assignment]
        self.hook_g: List[Tensor]  # type: ignore[assignment]
        weights: List[Tensor] = []
        for act, grad in zip(self.hook_a, self.hook_g):
            weighted_grad_sum = (grad * act).flatten(2).sum(-1)
            act_sum = act.flatten(2).sum(-1).add(eps)
            weights.append(weighted_grad_sum / act_sum)
        return weights
class LayerCAM(_GradCAM):
    r"""Implements a class activation map extractor as described in `"LayerCAM: Exploring Hierarchical Class Activation
    Maps for Localization" <http://mmcheng.net/mftp/Papers/21TIP_LayerCAM.pdf>`_.

    The localization map is computed as follows:

    .. math::
        L^{(c)}_{Layer-CAM}(x, y) = ReLU\Big(\sum\limits_k w_k^{(c)}(x, y) \cdot A_k(x, y)\Big)

    where the coefficient :math:`w_k^{(c)}(x, y)` is the rectified gradient at each position:

    .. math::
        w_k^{(c)}(x, y) = ReLU\Big(\frac{\partial Y^{(c)}}{\partial A_k(x, y)}\Big)

    Here :math:`A_k(x, y)` is the activation of node :math:`k` in the target layer of the model at
    position :math:`(x, y)`, and :math:`Y^{(c)}` is the model output score for class :math:`c` before softmax.

    >>> from torchvision.models import resnet18
    >>> from torchcam.methods import LayerCAM
    >>> model = resnet18(pretrained=True).eval()
    >>> extractor = LayerCAM(model, 'layer4')
    >>> scores = model(input_tensor)
    >>> cams = extractor(class_idx=100, scores=scores)
    >>> fused_cam = extractor.fuse_cams(cams)

    Args:
        model: input model
        target_layer: either the target layer itself or its name, or a list of those
        input_shape: shape of the expected input tensor excluding the batch dimension
    """

    def _get_weights(self, class_idx: Union[int, List[int]], scores: Tensor, **kwargs: Any) -> List[Tensor]:
        """Return the element-wise rectified gradient of every target layer (no spatial reduction).

        Args:
            class_idx: a single class index for the whole batch, or one class index per sample
            scores: model output scores to backpropagate from
            kwargs: forwarded to ``_backprop`` (e.g. ``retain_graph``)

        Returns:
            one tensor per target layer, with the same shape as the hooked gradient
        """
        self._backprop(scores, class_idx, **kwargs)
        self.hook_g: List[Tensor]  # type: ignore[assignment]
        # Gradients keep their full shape, e.g. (N, C, H, W)
        return list(map(torch.relu, self.hook_g))

    @staticmethod
    def _scale_cams(cams: List[Tensor], gamma: float = 2.0) -> List[Tensor]:
        """Squash each CAM through ``tanh(gamma * cam)`` (cf. Equation 9 in the paper)."""
        scaled: List[Tensor] = []
        for cam in cams:
            scaled.append(torch.tanh(cast(Tensor, gamma * cam)))
        return scaled