Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option for output shape of ViT #530

Merged
merged 4 commits into from
May 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions mmseg/models/backbones/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class VisionTransformer(nn.Module):
and its variants only. Default: False.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Default: False.
out_reshape (str): Select the output format of feature information.
Default: NCHW.
interpolate_mode (str): Select the interpolate mode for position
embeding vector resize. Default: bicubic.
with_cls_token (bool): If concatenating class token into image tokens
Expand Down Expand Up @@ -261,6 +263,7 @@ def __init__(self,
act_cfg=dict(type='GELU'),
norm_eval=False,
final_norm=False,
out_shape='NCHW',
with_cls_token=True,
interpolate_mode='bicubic',
with_cp=False):
Expand Down Expand Up @@ -303,6 +306,11 @@ def __init__(self,
with_cp=with_cp) for i in range(depth)
])

assert out_shape in ['NLC',
'NCHW'], 'output shape must be "NLC" or "NCHW".'

self.out_shape = out_shape

self.interpolate_mode = interpolate_mode
self.final_norm = final_norm
if final_norm:
Expand Down Expand Up @@ -443,10 +451,11 @@ def forward(self, inputs):
out = x[:, 1:]
else:
out = x
B, _, C = out.shape
out = out.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
if self.out_shape == 'NCHW':
B, _, C = out.shape
out = out.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
outs.append(out)

return tuple(outs)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_models/test_backbones/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def test_vit_backbone():
model = VisionTransformer()
model(x)

with pytest.raises(AssertionError):
# out_shape must be 'NLC' or 'NCHW;'
VisionTransformer(out_shape='NCL')

# Test img_size isinstance int
imgs = torch.randn(1, 3, 224, 224)
model = VisionTransformer(img_size=224)
Expand Down Expand Up @@ -72,3 +76,9 @@ def test_vit_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)

# Test final reshape arg
imgs = torch.randn(1, 3, 224, 224)
model = VisionTransformer(out_shape='NLC')
feat = model(imgs)
assert feat[-1].shape == (1, 196, 768)