# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

class GradCAM:
    """GradCAM class helps create visualization results.

    Visualization results are blended by heatmaps and input images.
    This class is modified from
    https://github.com/facebookresearch/SlowFast/blob/master/slowfast/visualization/gradcam_utils.py # noqa
    For more information about GradCAM, please visit:
    https://arxiv.org/pdf/1610.02391.pdf

    Args:
        model (nn.Module): the recognizer model to be used.
        target_layer_name (str): name of convolutional layer to
            be used to get gradients and feature maps from for creating
            localization maps.
        colormap (str): matplotlib colormap used to create
            heatmap. Defaults to 'viridis'. For more information, please visit
            https://matplotlib.org/3.3.0/tutorials/colors/colormaps.html
    """

    def __init__(self,
                 model: nn.Module,
                 target_layer_name: str,
                 colormap: str = 'viridis') -> None:
        from ..models.recognizers import Recognizer2D, Recognizer3D

        if isinstance(model, Recognizer2D):
            self.is_recognizer2d = True
        elif isinstance(model, Recognizer3D):
            self.is_recognizer2d = False
        else:
            raise ValueError(
                'GradCAM utils only support Recognizer2D & Recognizer3D.')

        self.model = model
        self.model.eval()
        self.target_gradients = None
        self.target_activations = None

        import matplotlib.pyplot as plt
        self.colormap = plt.get_cmap(colormap)
        self._register_hooks(target_layer_name)

    def _register_hooks(self, layer_name: str) -> None:
        """Register forward and backward hook to a layer, given layer_name, to
        obtain gradients and activations.

        Args:
            layer_name (str): name of the layer.
        """

        def get_gradients(module, grad_input, grad_output):
            self.target_gradients = grad_output[0].detach()

        def get_activations(module, input, output):
            self.target_activations = output.clone().detach()

        layer_ls = layer_name.split('/')
        prev_module = self.model
        for layer in layer_ls:
            prev_module = prev_module._modules[layer]

        target_layer = prev_module
        target_layer.register_forward_hook(get_activations)
        target_layer.register_backward_hook(get_gradients)

    def _calculate_localization_map(self,
                                    data: dict,
                                    use_labels: bool,
                                    delta=1e-20) -> tuple:
        """Calculate localization map for all inputs with Grad-CAM.

        Args:
            data (dict): model inputs, generated by test pipeline.
            use_labels (bool): Whether to use given labels to generate
                localization map.
            delta (float): used in localization map normalization,
                must be small enough. Please make sure
                `localization_map_max - localization_map_min >> delta`

        Returns:
            localization_map (torch.Tensor): the localization map for
                input imgs.
            preds (torch.Tensor): Model predictions with shape
                (batch_size, num_classes).
        """
        inputs = data['inputs']

        # use score before softmax
        self.model.cls_head.average_clips = 'score'

        # model forward & backward
        results = self.model.test_step(data)
        preds = [result.pred_score for result in results]
        preds = torch.stack(preds)

        if use_labels:
            labels = [result.gt_label for result in results]
            labels = torch.stack(labels)
            score = torch.gather(preds, dim=1, index=labels)
        else:
            score = torch.max(preds, dim=-1)[0]
        self.model.zero_grad()
        score = torch.sum(score)
        score.backward()

        imgs = torch.stack(inputs)
        if self.is_recognizer2d:
            # [batch_size, num_segments, 3, H, W]
            b, t, _, h, w = imgs.size()
        else:
            # [batch_size, num_crops*num_clips, 3, clip_len, H, W]
            b1, b2, _, t, h, w = imgs.size()
            b = b1 * b2

        gradients = self.target_gradients
        activations = self.target_activations
        if self.is_recognizer2d:
            # [B*Tg, C', H', W']
            b_tg, c, _, _ = gradients.size()
            tg = b_tg // b
        else:
            # source shape: [B, C', Tg, H', W']
            _, c, tg, _, _ = gradients.size()
            # target shape: [B, Tg, C', H', W']
            gradients = gradients.permute(0, 2, 1, 3, 4)
            activations = activations.permute(0, 2, 1, 3, 4)

        # calculate & resize to [B, 1, T, H, W]
        # Grad-CAM channel weights: spatial average of the gradients of the
        # score w.r.t. each activation channel (global average pooling)
        weights = torch.mean(gradients.view(b, tg, c, -1), dim=3)
        weights = weights.view(b, tg, c, 1, 1)
        activations = activations.view([b, tg, c] +
                                       list(activations.size()[-2:]))
        localization_map = torch.sum(
            weights * activations, dim=2, keepdim=True)
        localization_map = F.relu(localization_map)
        localization_map = localization_map.permute(0, 2, 1, 3, 4)
        localization_map = F.interpolate(
            localization_map,
            size=(t, h, w),
            mode='trilinear',
            align_corners=False)

        # Normalize the localization map.
        localization_map_min, localization_map_max = (
            torch.min(localization_map.view(b, -1), dim=-1, keepdim=True)[0],
            torch.max(localization_map.view(b, -1), dim=-1, keepdim=True)[0])
        localization_map_min = torch.reshape(
            localization_map_min, shape=(b, 1, 1, 1, 1))
        localization_map_max = torch.reshape(
            localization_map_max, shape=(b, 1, 1, 1, 1))
        localization_map = (localization_map - localization_map_min) / (
            localization_map_max - localization_map_min + delta)
        localization_map = localization_map.data

        return localization_map.squeeze(dim=1), preds

    def _alpha_blending(self, localization_map: torch.Tensor,
                        input_imgs: torch.Tensor,
                        alpha: float) -> torch.Tensor:
        """Blend heatmaps and model input images to get visualization results.

        Args:
            localization_map (torch.Tensor): localization map for all inputs,
                generated with Grad-CAM.
            input_imgs (torch.Tensor): model inputs, raw images.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].

        Returns:
            torch.Tensor: blending results for localization map and input
                images, with shape [B, T, H, W, 3] and pixel values in
                RGB order within range [0, 1].
        """
        # localization_map shape [B, T, H, W]
        localization_map = localization_map.cpu()

        # heatmap shape [B, T, H, W, 3] in RGB order
        heatmap = self.colormap(localization_map.detach().numpy())
        heatmap = heatmap[..., :3]
        heatmap = torch.from_numpy(heatmap)

        input_imgs = torch.stack(input_imgs)
        # Permute input imgs to [B, T, H, W, 3], like heatmap
        if self.is_recognizer2d:
            # Recognizer2D input (B, T, C, H, W)
            curr_inp = input_imgs.permute(0, 1, 3, 4, 2)
        else:
            # Recognizer3D input (B', num_clips*num_crops, C, T, H, W)
            # B = B' * num_clips * num_crops
            curr_inp = input_imgs.view([-1] + list(input_imgs.size()[2:]))
            curr_inp = curr_inp.permute(0, 2, 3, 4, 1)

        # renormalize input imgs to [0, 1]
        curr_inp = curr_inp.cpu().float()
        curr_inp /= 255.

        # alpha blending
        blended_imgs = alpha * heatmap + (1 - alpha) * curr_inp

        return blended_imgs

    def __call__(self,
                 data: dict,
                 use_labels: bool = False,
                 alpha: float = 0.5) -> tuple:
        """Visualize the localization maps on their corresponding inputs as
        heatmap, using Grad-CAM.

        Generate visualization results for **ALL CROPS**.
        For example, for an I3D model, if `clip_len=32, num_clips=10` and
        `ThreeCrop` is used in the test pipeline, then for each model input
        there are 960 (32*10*3) images generated.

        Args:
            data (dict): model inputs, generated by test pipeline.
            use_labels (bool): Whether to use given labels to generate
                localization map.
            alpha (float): transparency level of the heatmap,
                in the range [0, 1].

        Returns:
            blended_imgs (torch.Tensor): Visualization results, blended by
                localization maps and model inputs.
            preds (torch.Tensor): Model predictions for inputs.
        """
        # localization_map shape [B, T, H, W]
        # preds shape [batch_size, num_classes]
        localization_map, preds = self._calculate_localization_map(
            data, use_labels=use_labels)

        # blended_imgs shape [B, T, H, W, 3]
        blended_imgs = self._alpha_blending(localization_map, data['inputs'],
                                            alpha)

        # blended_imgs shape [B, T, H, W, 3]
        # preds shape [batch_size, num_classes]
        # Recognizer2D: B = batch_size, T = num_segments
        # Recognizer3D: B = batch_size * num_crops * num_clips, T = clip_len
        return blended_imgs, preds
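

# Minimal usage sketch (not part of the original module). Because GradCAM
# uses package-relative imports, the snippet below belongs in a separate
# user script; the config path, checkpoint path, target layer name
# ('backbone/layer4') and the GradCAM import path are assumptions, shown
# only to illustrate the expected call sequence.
#
#   from mmaction.apis import init_recognizer
#   from mmaction.utils import GradCAM  # assumed export location
#
#   model = init_recognizer('configs/my_config.py', 'my_checkpoint.pth',
#                           device='cuda:0')
#   gradcam = GradCAM(model, target_layer_name='backbone/layer4')
#   for data in test_dataloader:  # dict with 'inputs', built by test pipeline
#       blended_imgs, preds = gradcam(data, use_labels=False, alpha=0.5)
#       # blended_imgs: [B, T, H, W, 3], RGB values in [0, 1]
#       break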