* Store intermediate token features and apply no processing to them;
* Remove the class token and reshape the entire token feature from NLC to NCHW;
sennnnn committed Apr 29, 2021
1 parent 2f580e9 commit e1d59cd
Showing 2 changed files with 28 additions and 19 deletions.
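
In short, the ViT backbone now keeps the token features produced by every transformer block and, for each of them, drops the class token and reshapes the remaining tokens from NLC (batch, tokens, channels) to NCHW so that decoder heads can consume them as ordinary feature maps. A minimal sketch of that reshape, using hypothetical sizes (224x224 input, 16x16 patches, embedding dimension 768) rather than the real module:

```python
import torch

# Hypothetical sizes: 224x224 input, 16x16 patches, embed dim 768.
B, H, W, patch_size, C = 2, 224, 224, 16, 768
num_patches = (H // patch_size) * (W // patch_size)  # 14 * 14 = 196

# Token features as a transformer block emits them: class token + patch tokens (NLC).
tokens = torch.randn(B, num_patches + 1, C)

# Drop the class token, then reshape NLC -> NCHW for a decoder head.
out = tokens[:, 1:]
out = out.reshape(B, H // patch_size, W // patch_size, C).permute(0, 3, 1, 2)
assert out.shape == (B, C, 14, 14)
```
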
39 changes: 24 additions & 15 deletions mmseg/models/backbones/vit.py
@@ -210,7 +210,7 @@ class VisionTransformer(nn.Module):
Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
Args:
-        img_size (tuple): input image size. Default: (224, 224.
+        img_size (tuple): input image size. Default: (224, 224).
patch_size (int, tuple): patch size. Default: 16.
in_channels (int): number of input channels. Default: 3.
embed_dim (int): embedding dimension. Default: 768.
@@ -270,7 +270,10 @@ def __init__(self,
torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)

-        self.blocks = nn.Sequential(*[
+        self.num_stages = depth
+        self.out_indices = tuple(range(self.num_stages))
+
+        self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
@@ -281,7 +284,7 @@ def __init__(self,
attn_drop=attn_drop_rate,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
-                with_cp=with_cp) for i in range(depth)
+                with_cp=with_cp) for i in range(self.num_stages)
])

self.interpolate_mode = interpolate_mode
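
Note the container change above: nn.Sequential applies all blocks in one call, so only the final output is visible, whereas nn.ModuleList registers the blocks as submodules but leaves iteration to forward(), which is what lets the new code collect every block's output. A minimal sketch of the pattern, with a toy block standing in for the real mmseg Block:

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Stand-in for a transformer block, just for illustration."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.fc(x))

blocks = nn.ModuleList([ToyBlock(8) for _ in range(4)])

x = torch.randn(2, 5, 8)  # (batch, tokens, channels)
feats = []
for blk in blocks:        # explicit loop keeps every block's output reachable
    x = blk(x)
    feats.append(x)
assert len(feats) == 4    # one feature per block; nn.Sequential would expose only the last
```
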
@@ -407,18 +410,24 @@ def forward(self, inputs):
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self._pos_embeding(inputs, x, self.pos_embed)
-        x = self.blocks(x)
-
-        if self.final_norm:
-            x = self.norm(x)
-
-        # Remove class token
-        x = x[:, 1:]
-        B, _, C = x.shape
-        x = x.reshape(B, inputs.shape[2] // self.patch_size,
-                      inputs.shape[3] // self.patch_size,
-                      C).permute(0, 3, 1, 2)
-        return [x]
+
+        outs = []
+        block_len = len(self.blocks)
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i == block_len - 1:
+                if self.final_norm:
+                    x = self.norm(x)
+            if i in self.out_indices:
+                # Remove class token and reshape token for decoder head
+                out = x[:, 1:]
+                B, _, C = out.shape
+                out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                  inputs.shape[3] // self.patch_size,
+                                  C).permute(0, 3, 1, 2)
+                outs.append(out)
+
+        return tuple(outs)

def train(self, mode=True):
super(VisionTransformer, self).train(mode)
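
With the forward() change above, the backbone returns a tuple containing one NCHW feature map per transformer block rather than a single-element list, which is why the tests below index feat[-1]. A hedged usage sketch, assuming the import path matches the file in this diff and the default configuration (embed_dim 768, patch_size 16, 224x224 input):

```python
import torch
from mmseg.models.backbones.vit import VisionTransformer

# Assumed defaults: embed_dim=768, patch_size=16, img_size=(224, 224).
model = VisionTransformer()
model.eval()

imgs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model(imgs)

# out_indices covers every stage, so there is one feature map per block.
assert len(feats) == len(model.blocks)
assert feats[-1].shape == (1, 768, 14, 14)  # 224 / 16 = 14
```
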
8 changes: 4 additions & 4 deletions tests/test_models/test_backbones/test_vit.py
@@ -46,19 +46,19 @@ def test_vit_backbone():
# Test large size input image
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
-    assert feat[0].shape == (1, 768, 16, 16)
+    assert feat[-1].shape == (1, 768, 16, 16)

# Test small size input image
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
-    assert feat[0].shape == (1, 768, 2, 2)
+    assert feat[-1].shape == (1, 768, 2, 2)

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
-    assert feat[0].shape == (1, 768, 14, 14)
+    assert feat[-1].shape == (1, 768, 14, 14)

# Test with_cp=True
model = VisionTransformer(with_cp=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
-    assert feat[0].shape == (1, 768, 14, 14)
+    assert feat[-1].shape == (1, 768, 14, 14)
