From c0fc43eaac27ee56fd29b30c11c4873961bd7939 Mon Sep 17 00:00:00 2001
From: HubHop
Date: Thu, 10 Mar 2022 11:18:48 +1100
Subject: [PATCH] add visualisation code for attention maps

---
 README.md                                                  |  42 ++-
 .../generate_attention_maps.py                             | 160 ++++++++++
 .../code_for_lit_ti/pvt_full_msa.py                        | 297 ++++++++++++++++++
 3 files changed, 498 insertions(+), 1 deletion(-)
 create mode 100644 classification/code_for_lit_ti/generate_attention_maps.py
 create mode 100644 classification/code_for_lit_ti/pvt_full_msa.py

diff --git a/README.md b/README.md
index 46bb502..3f831b7 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,9 @@ If you use this code for a paper please cite:
 }
 ```
 
+## Updates
+
+- 10/03/2022: Added visualisation code for the attention maps in Figure 3.
 
 ## Usage
 
@@ -112,13 +115,50 @@ We provide a script for visualising the learned offsets by the proposed deformab
 conda activate lit
 cd classification/code_for_lit_ti
 
-# visualize
+# visualise
 python visualize_offset.py --model lit_ti --resume [path/to/lit_ti.pth] --vis_image visualization/demo.JPEG
 ```
 
 The plots will be automatically saved under `visualization/`, with a folder named by the name of the example image.
 
+## Attention Map Visualisation
+
+We provide the script used to visualise the attention maps in Figure 3. For convenience, we also provide a pretrained PVT model that uses standard MSA in all four stages.
+
+| Name       | Params (M) | FLOPs (G) | Top-1 Acc. (%) | Model  | Log |
+| ---------- | ---------- | --------- | -------------- | ------ | --- |
+| PVT w/ MSA | 20         | 8.4       | 80.9           | github | log |
+
+```bash
+conda activate lit
+cd classification/code_for_lit_ti
+
+# visualise
+# by default, we save the results under 'classification/code_for_lit_ti/attn_results'
+python generate_attention_maps.py --data-path [/path/to/imagenet] --resume [/path/to/pvt_full_msa.pth]
+```
+
+The resulting folder contains the following items:
+
+```
+.
+├── attention_map
+│   ├── stage-0
+│   │   ├── block0
+│   │   │   └── pixel-1260-block-0-head-0.png
+│   │   ├── block1
+│   │   │   └── pixel-1260-block-1-head-0.png
+│   │   └── block2
+│   │       └── pixel-1260-block-2-head-0.png
+│   ├── stage-1
+│   ├── stage-2
+│   └── stage-3
+└── full_msa_eval_maps.npy
+```
+
+where `full_msa_eval_maps.npy` stores the aggregated attention maps for each block in each stage, and the folder `attention_map` contains the rendered visualisations.
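+
+To quickly inspect the saved attention maps without re-rendering the figures, a minimal sketch (assuming the default `attn_results` output directory and the default 224x224 input size) looks like this:
+
+```python
+import numpy as np
+
+# load the file written by generate_attention_maps.py
+maps = np.load('attn_results/full_msa_eval_maps.npy', allow_pickle=True)
+
+for stage_id, stage_maps in enumerate(maps):
+    for block_id, block_map in enumerate(stage_maps):
+        # each map has shape (num_heads, seq_len, seq_len); the last stage
+        # also carries the CLS token, so seq_len there is 7 * 7 + 1 = 50
+        print(stage_id, block_id, block_map.shape)
+```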
+
 
 ## License
 
diff --git a/classification/code_for_lit_ti/generate_attention_maps.py b/classification/code_for_lit_ti/generate_attention_maps.py
new file mode 100644
index 0000000..682c2a4
--- /dev/null
+++ b/classification/code_for_lit_ti/generate_attention_maps.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import utils
+from timm.models import create_model
+import pvt_full_msa
+import os
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+import seaborn as sns
+import torch.backends.cudnn as cudnn
+from params import args
+from datasets import build_dataset
+
+sns.set_theme()
+cmap = sns.light_palette("#A20000", as_cmap=True)
+
+
+@torch.no_grad()
+def get_attention_data(data_loader, model, device, base_path):
+    criterion = torch.nn.CrossEntropyLoss()
+
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = 'Test:'
+
+    # switch to evaluation mode
+    model.eval()
+
+    attention_store = []
+    samples = 0
+    for images, target in metric_logger.log_every(data_loader, 10, header):
+        images = images.to(device, non_blocking=True)
+        target = target.to(device, non_blocking=True)
+        samples += images.size()[0]
+        # compute output
+        with torch.cuda.amp.autocast():
+            output, attention_maps = model(images)
+            loss = criterion(output, target)
+        if len(attention_store) == 0:
+            for i, stage_maps in enumerate(attention_maps):
+                stage_attns = []
+                for j, block_maps in enumerate(stage_maps):
+                    # Sum the attention probabilities over the batch dimension;
+                    # averaging or other scaling schemes also work.
+                    stage_attns.append(block_maps.sum(dim=0))
+                attention_store.append(stage_attns)
+        else:
+            for i, stage_maps in enumerate(attention_maps):
+                for j, block_maps in enumerate(stage_maps):
+                    attention_store[i][j] += block_maps.sum(dim=0)
+
+        np_attns = []
+        for i, stage_maps in enumerate(attention_store):
+            stage_attns = []
+            for j, block_maps in enumerate(stage_maps):
+                block_maps /= samples
+                stage_attns.append(block_maps.numpy())
+            np_attns.append(stage_attns)
+        # the nesting is ragged (stages have different depths), hence dtype=object
+        np.save(os.path.join(base_path, 'full_msa_eval_maps.npy'), np.array(np_attns, dtype=object))
+        break
+
+
+def visualize_attentions(base_path):
+
+    save_path = os.path.join(base_path, 'attention_map')
+    attention_maps = np.load(os.path.join(base_path, 'full_msa_eval_maps.npy'), allow_pickle=True)
+
+    linewidths = [1, 1, 2, 2]
+    # Remember that PVT has 4 stages
+    for stage_id, stage_attn_map in enumerate(attention_maps):
+        # each stage has several Transformer blocks
+        for block_id, block_attn_map in enumerate(stage_attn_map):
+
+            block_attn_map = torch.from_numpy(block_attn_map)  # size: num_heads * seq_len * seq_len
+
+            # PVT only uses a CLS token in the last stage; exclude it for a cleaner visualisation.
+            if stage_id == 3:
+                block_attn_map = block_attn_map[:, 1:, 1:]
+
+            H, N, _ = block_attn_map.shape
+            width = int(N ** (1 / 2))
+
+            # iterate over the self-attention heads
+            for head_id in range(H):
+                head_attn_map = block_attn_map[head_id, ...]
+                map_save_dir = os.path.join(save_path, 'stage-' + str(stage_id), 'block' + str(block_id))
+
+                if not os.path.exists(map_save_dir):
+                    os.makedirs(map_save_dir, exist_ok=True)
+
+                for pixel_id in range(N):
+                    # One fixed query pixel per stage, chosen so that the visualised pixel sits near the centre of the feature map.
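+                    # With the default 224x224 input, the stage feature maps are 56x56, 28x28,
+                    # 14x14 and 7x7, so these indices correspond to (row, col) = (22, 28),
+                    # (10, 14), (6, 8) and (2, 3) respectively.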
+                    if stage_id == 0 and pixel_id != 1260:
+                        continue
+                    if stage_id == 1 and pixel_id != 294:
+                        continue
+                    if stage_id == 2 and pixel_id != 92:
+                        continue
+                    if stage_id == 3 and pixel_id != 17:
+                        continue
+
+                    plt.clf()
+                    f, ax = plt.subplots(1, 1, figsize=(4, 4))
+                    ax.set_aspect('equal')
+
+                    print(stage_id, block_id, head_id, pixel_id)
+
+                    pixel_attn_map = head_attn_map[pixel_id, ...].reshape(width, width).numpy()
+
+                    x = int(pixel_id % width)
+                    y = int(pixel_id // width)
+
+                    # visualise the attention map with a seaborn heatmap and mark the query pixel
+                    ax = sns.heatmap(pixel_attn_map, cmap="OrRd", cbar=False, xticklabels=False, yticklabels=False, ax=ax)
+                    patch = patches.Rectangle((x, y), 1, 1, linewidth=linewidths[stage_id], edgecolor='lime', facecolor='None')
+                    ax.add_patch(patch)
+                    image_name = 'pixel-{}-block-{}-head-{}.png'.format(pixel_id, block_id, head_id)
+                    plt.savefig(os.path.join(map_save_dir, image_name), transparent=True)
+                    # close the figure so open figures do not accumulate in memory
+                    plt.close(f)
+
+
+if __name__ == '__main__':
+    # You may change the path for saving the results.
+    save_path = 'attn_results'
+    if not os.path.exists(save_path):
+        os.makedirs(save_path, exist_ok=True)
+
+    dataset_val, _ = build_dataset(is_train=False, args=args)
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    cudnn.benchmark = True
+
+    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, sampler=sampler_val,
+        batch_size=100,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_mem,
+        drop_last=False
+    )
+
+    model = create_model(
+        'pvt_small_full_msa',
+        pretrained=False,
+        num_classes=1000,
+        drop_rate=args.drop,
+        drop_path_rate=args.drop_path,
+        drop_block_rate=None,
+    )
+    model.to(device)
+    checkpoint = torch.load(args.resume)
+    model.load_state_dict(checkpoint)
+    get_attention_data(data_loader_val, model, device, save_path)
+    visualize_attentions(save_path)
+
+
diff --git a/classification/code_for_lit_ti/pvt_full_msa.py b/classification/code_for_lit_ti/pvt_full_msa.py
new file mode 100644
index 0000000..f31baa5
--- /dev/null
+++ b/classification/code_for_lit_ti/pvt_full_msa.py
@@ -0,0 +1,297 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from functools import partial
+
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from timm.models.vision_transformer import _cfg
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class ViTAttentionVis(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+
+        attn = attn.softmax(dim=-1)
+        # keep a copy of the attention probabilities for visualisation
+        attn_vis = attn.clone()
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn_vis
+
+
+class AttnBlockAnalyse(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, seq_len=0):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.sr_ratio = sr_ratio
+        self.attn = ViTAttentionVis(
+            dim,
+            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        output, attn = self.attn(self.norm1(x))
+        x = x + self.drop_path(output)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x, attn
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+
+        self.img_size = img_size
+        self.patch_size = patch_size
+        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
+            f"img_size {img_size} should be divisible by patch_size {patch_size}."
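+        # number of patches along the height and width of the input feature map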
+        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+        self.num_patches = self.H * self.W
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        x = self.norm(x)
+        H, W = H // self.patch_size[0], W // self.patch_size[1]
+
+        return x, (H, W)
+
+
+class PVT_Full_MSA(nn.Module):
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
+                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
+                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
+                 depths=[3, 4, 6, 3], sr_ratios=[2, 2, 1, 1]):
+        super().__init__()
+        self.num_classes = num_classes
+        self.depths = depths
+
+        # patch_embed
+        self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+                                       embed_dim=embed_dims[0])
+        self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
+                                       embed_dim=embed_dims[1])
+        self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
+                                       embed_dim=embed_dims[2])
+        self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
+                                       embed_dim=embed_dims[3])
+
+        # pos_embed
+        self.pos_embed1 = nn.Parameter(torch.zeros(1, self.patch_embed1.num_patches, embed_dims[0]))
+        self.pos_drop1 = nn.Dropout(p=drop_rate)
+        self.pos_embed2 = nn.Parameter(torch.zeros(1, self.patch_embed2.num_patches, embed_dims[1]))
+        self.pos_drop2 = nn.Dropout(p=drop_rate)
+        self.pos_embed3 = nn.Parameter(torch.zeros(1, self.patch_embed3.num_patches, embed_dims[2]))
+        self.pos_drop3 = nn.Dropout(p=drop_rate)
+        self.pos_embed4 = nn.Parameter(torch.zeros(1, self.patch_embed4.num_patches + 1, embed_dims[3]))
+        self.pos_drop4 = nn.Dropout(p=drop_rate)
+
+        # transformer encoder
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+        cur = 0
+
+        image_width = img_size // 4
+        self.block1 = nn.ModuleList([AttnBlockAnalyse(
+            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
+            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+            sr_ratio=sr_ratios[0], seq_len=int(image_width ** 2))
+            for i in range(depths[0])])
+
+        cur += depths[0]
+        image_width = image_width // 2
+        self.block2 = nn.ModuleList([AttnBlockAnalyse(
+            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
+            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+            sr_ratio=sr_ratios[1], seq_len=int(image_width ** 2))
+            for i in range(depths[1])])
+
+        cur += depths[1]
+        image_width = image_width // 2
+        self.block3 = nn.ModuleList([AttnBlockAnalyse(
+            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
+            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+            sr_ratio=sr_ratios[2], seq_len=int(image_width ** 2))
+            for i in range(depths[2])])
+
+        cur += depths[2]
+        image_width = image_width // 2
+        self.block4 = nn.ModuleList([AttnBlockAnalyse(
+            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
+            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
+            sr_ratio=sr_ratios[3], seq_len=int(image_width ** 2))
+            for i in range(depths[3])])
+        self.norm = norm_layer(embed_dims[3])
+
+        # cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
+
+        # classification head
+        self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
+
+        # self.lasso = 0.
+
+        # init weights
+        trunc_normal_(self.pos_embed1, std=.02)
+        trunc_normal_(self.pos_embed2, std=.02)
+        trunc_normal_(self.pos_embed3, std=.02)
+        trunc_normal_(self.pos_embed4, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def reset_drop_path(self, drop_path_rate):
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
+        cur = 0
+        for i in range(self.depths[0]):
+            self.block1[i].drop_path.drop_prob = dpr[cur + i]
+
+        cur += self.depths[0]
+        for i in range(self.depths[1]):
+            self.block2[i].drop_path.drop_prob = dpr[cur + i]
+
+        cur += self.depths[1]
+        for i in range(self.depths[2]):
+            self.block3[i].drop_path.drop_prob = dpr[cur + i]
+
+        cur += self.depths[2]
+        for i in range(self.depths[3]):
+            self.block4[i].drop_path.drop_prob = dpr[cur + i]
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'cls_token'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def get_attention_maps(self, x):
+        attentions = []
+        B = x.shape[0]
+        # stage 1
+        stage_attentions = []
+        x, (H, W) = self.patch_embed1(x)
+        x = x + self.pos_embed1
+        x = self.pos_drop1(x)
+        for blk in self.block1:
+            x, attn = blk(x)
+            stage_attentions.append(attn.detach().cpu())
+        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        attentions.append(stage_attentions)
+
+        # stage 2
+        stage_attentions = []
+        x, (H, W) = self.patch_embed2(x)
+        x = x + self.pos_embed2
+        x = self.pos_drop2(x)
+        for blk in self.block2:
+            x, attn = blk(x)
+            stage_attentions.append(attn.detach().cpu())
+        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        attentions.append(stage_attentions)
+
+        # stage 3
+        stage_attentions = []
+        x, (H, W) = self.patch_embed3(x)
+        x = x + self.pos_embed3
+        x = self.pos_drop3(x)
+        for blk in self.block3:
+            x, attn = blk(x)
+            stage_attentions.append(attn.detach().cpu())
+        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        attentions.append(stage_attentions)
+
+        # stage 4
+        stage_attentions = []
+        x, (H, W) = self.patch_embed4(x)
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = x + self.pos_embed4
+        x = self.pos_drop4(x)
+        for blk in self.block4:
+            x, attn = blk(x)
+            stage_attentions.append(attn.detach().cpu())
+        attentions.append(stage_attentions)
+        x = self.norm(x)
+
+        return x[:, 0], attentions
+
+    def forward(self, x):
+        x, attentions = self.get_attention_maps(x)
+        x = self.head(x)
+
+        return x, attentions
+
+
+@register_model
+def pvt_small_full_msa(pretrained=False, **kwargs):
+    model = PVT_Full_MSA(
+        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3],
+        sr_ratios=[1, 1, 1, 1], **kwargs)
+    model.default_cfg = _cfg()
+
+    return model
\ No newline at end of file