Skip to content

Commit

Permalink
* Store the intermediate token features without applying any further processing to them;
Browse files Browse the repository at this point in the history
* Remove the class token and reshape the entire token feature from NLC to NCHW (in the final output stage only);
  • Loading branch information
sennnnn committed Apr 29, 2021
1 parent 2f580e9 commit 49722de
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
37 changes: 23 additions & 14 deletions mmseg/models/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,10 @@ def __init__(self,
torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)

self.blocks = nn.Sequential(*[
self.num_stages = depth
self.out_indices = tuple(range(self.num_stages))

self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
Expand All @@ -281,7 +284,7 @@ def __init__(self,
attn_drop=attn_drop_rate,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
with_cp=with_cp) for i in range(depth)
with_cp=with_cp) for i in range(self.num_stages)
])

self.interpolate_mode = interpolate_mode
Expand Down Expand Up @@ -407,18 +410,24 @@ def forward(self, inputs):
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self._pos_embeding(inputs, x, self.pos_embed)
x = self.blocks(x)

if self.final_norm:
x = self.norm(x)

# Remove class token
x = x[:, 1:]
B, _, C = x.shape
x = x.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
return [x]

outs = []
block_len = len(self.blocks)
for i, blk in enumerate(self.blocks):
x = blk(x)
if i == block_len - 1:
if self.final_norm:
x = self.norm(x)
# Remove class token and reshape token only in output stage
x = x[:, 1:]
B, _, C = x.shape
x = x.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
if i in self.out_indices:
outs.append(x)

return tuple(outs)

def train(self, mode=True):
super(VisionTransformer, self).train(mode)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_models/test_backbones/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ def test_vit_backbone():
# Test large size input image
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert feat[0].shape == (1, 768, 16, 16)
assert feat[-1].shape == (1, 768, 16, 16)

# Test small size input image
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert feat[0].shape == (1, 768, 2, 2)
assert feat[-1].shape == (1, 768, 2, 2)

imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[0].shape == (1, 768, 14, 14)
assert feat[-1].shape == (1, 768, 14, 14)

# Test with_cp=True
model = VisionTransformer(with_cp=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[0].shape == (1, 768, 14, 14)
assert feat[-1].shape == (1, 768, 14, 14)

0 comments on commit 49722de

Please sign in to comment.