
[Fix] fix patch_embed and pos_embed mismatch error #685

Merged: 13 commits, Jul 19, 2021
37 changes: 17 additions & 20 deletions mmseg/models/backbones/vit.py
@@ -120,6 +120,7 @@ class VisionTransformer(BaseModule):
drop_path_rate (float): stochastic depth rate. Default 0.0
with_cls_token (bool): If concatenating class token into image tokens
as transformer input. Default: True.
output_cls_token (bool): Whether to output the cls_token. Default: False.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
@@ -128,8 +129,6 @@ class VisionTransformer(BaseModule):
Default: False.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Default: False.
out_shape (str): Select the output format of feature information.
Default: NCHW.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Default: bicubic.
num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +159,11 @@ def __init__(self,
attn_drop_rate=0.,
drop_path_rate=0.,
with_cls_token=True,
output_cls_token=False,
norm_cfg=dict(type='LN'),
act_cfg=dict(type='GELU'),
patch_norm=False,
final_norm=False,
out_shape='NCHW',
interpolate_mode='bicubic',
num_fcs=2,
norm_eval=False,
@@ -185,8 +184,9 @@ def __init__(self,

assert pretrain_style in ['timm', 'mmcls']

assert out_shape in ['NLC',
'NCHW'], 'output shape must be "NLC" or "NCHW".'
if output_cls_token:
    assert with_cls_token is True, 'with_cls_token must be True if ' \
        f'set output_cls_token to True, but got {with_cls_token}'

Collaborator: It is better to add this description to the Docstring.
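A rough usage sketch of the new flag pair (shapes taken from the tests added at the bottom of this diff; the exact import path is an assumption):

import torch
from mmseg.models.backbones import VisionTransformer

# With output_cls_token=True each backbone output is a pair
# [patch_feature_map, cls_token] rather than a single NCHW tensor.
model = VisionTransformer(with_cls_token=True, output_cls_token=True)
feat = model(torch.randn(1, 3, 224, 224))
patch_map, cls_token = feat[0]
print(patch_map.shape)  # torch.Size([1, 768, 14, 14])
print(cls_token.shape)  # torch.Size([1, 768])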

if isinstance(pretrained, str) or pretrained is None:
warnings.warn('DeprecationWarning: pretrained is a deprecated, '
@@ -196,7 +196,6 @@ def __init__(self,

self.img_size = img_size
self.patch_size = patch_size
self.out_shape = out_shape
self.interpolate_mode = interpolate_mode
self.norm_eval = norm_eval
self.with_cp = with_cp
@@ -218,6 +217,7 @@ def __init__(self,
(img_size[1] // patch_size)

self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +253,6 @@ def __init__(self,
batch_first=True))

self.final_norm = final_norm
self.out_shape = out_shape
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, embed_dims, postfix=1)
@@ -317,14 +316,13 @@ def init_weights(self):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)

def _pos_embeding(self, img, patched_img, pos_embed):
def _pos_embeding(self, downsampled_img_size, patched_img, pos_embed):
Collaborator, suggested change:
def _pos_embeding(self, downsampled_img_size, patched_img, pos_embed):
def _pos_embeding(self, x, hw_shape, pos_embed):

"""Positiong embeding method.

Resize the pos_embed, if the input image size doesn't match
the training size.
Args:
img (torch.Tensor): The inference image tensor, the shape
must be [B, C, H, W].
downsampled_img_size (tuple): The downsampled image resolution.
patched_img (torch.Tensor): The patched image, it should be
shape of [B, L1, C].
pos_embed (torch.Tensor): The pos_embed weights, it should be
@@ -344,7 +342,7 @@ def _pos_embeding(self, img, patched_img, pos_embed):
raise ValueError(
'Unexpected shape of pos_embed, got {}.'.format(
pos_embed.shape))
pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
pos_embed = self.resize_pos_embed(pos_embed, downsampled_img_size,
(pos_h, pos_w), self.patch_size,
self.interpolate_mode)
return self.drop_after_pos(patched_img + pos_embed)
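As a minimal standalone sketch of what this resize amounts to (the 14x14 source grid, the 15x22 target grid, and embed_dims=768 are assumptions taken from this PR's tests; resize_pos_embed below does the same with the cls token split off in the same way):

import torch
import torch.nn.functional as F

# pos_embed learned for a 224x224 input: 14*14 patch tokens + 1 cls token.
pos_embed = torch.zeros(1, 14 * 14 + 1, 768)
cls_pe, patch_pe = pos_embed[:, :1], pos_embed[:, 1:]
# Lay the patch entries back on their 2D grid and interpolate them to the
# downsampled grid of the new input (15x22 for a padded 234x345 image).
patch_pe = patch_pe.reshape(1, 14, 14, 768).permute(0, 3, 1, 2)
patch_pe = F.interpolate(
    patch_pe, size=(15, 22), mode='bicubic', align_corners=False)
patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, 15 * 22, 768)
resized = torch.cat([cls_pe, patch_pe], dim=1)
assert resized.shape == (1, 15 * 22 + 1, 768)  # now matches the token count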
@@ -371,7 +369,7 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
pos_embed_weight = F.interpolate(
pos_embed_weight,
size=[input_h // patch_size, input_w // patch_size],
size=[input_h, input_w],
Collaborator @Junjun2016 (Jul 8, 2021), suggested change:
size=[input_h, input_w],
size=input_shpae,

so input_h, input_w = input_shpae is redundant.

align_corners=False,
mode=mode)
cls_token_weight = cls_token_weight.unsqueeze(1)
@@ -382,12 +380,12 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
def forward(self, inputs):
B = inputs.shape[0]

x = self.patch_embed(inputs)

x, H, W = self.patch_embed(
inputs), self.patch_embed.DH, self.patch_embed.DW
Collaborator: We may use hw_shape instead, to be consistent with Swin Transformer.

# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = self._pos_embeding(inputs, x, self.pos_embed)
x = self._pos_embeding((H, W), x, self.pos_embed)

if not self.with_cls_token:
# Remove class token for transformer encoder input
@@ -405,11 +403,10 @@ def forward(self, inputs):
out = x[:, 1:]
else:
out = x
if self.out_shape == 'NCHW':
B, _, C = out.shape
out = out.reshape(B, inputs.shape[2] // self.patch_size,
inputs.shape[3] // self.patch_size,
C).permute(0, 3, 1, 2)
B, _, C = out.shape
out = out.reshape(B, H, W, C).permute(0, 3, 1, 2)
if self.output_cls_token:
out = [out, x[:, 0]]
outs.append(out)

return tuple(outs)
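In short, forward now reshapes with the (DH, DW) grid recorded by patch_embed instead of floor-dividing the raw image size. A back-of-the-envelope check (patch_size=16 and the pad-to-whole-patches behaviour are assumptions read off the new 234x345 test below):

import math

h_img, w_img, patch = 234, 345, 16
# Old reshape: floor division on the raw image size.
old_grid = (h_img // patch, w_img // patch)          # (14, 21) -> 294 tokens
# New reshape: the grid patch_embed actually produced, i.e. the ceiling
# for a padded input, which matches the real token count.
new_grid = (math.ceil(h_img / patch), math.ceil(w_img / patch))  # (15, 22)
assert new_grid[0] * new_grid[1] == 330
assert old_grid[0] * old_grid[1] != 330  # the mismatch the old code hit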
22 changes: 14 additions & 8 deletions tests/test_models/test_backbones/test_vit.py
@@ -39,8 +39,8 @@ def test_vit_backbone():
VisionTransformer(pretrained=123)

with pytest.raises(AssertionError):
# out_shape must be 'NLC' or 'NCHW;'
VisionTransformer(out_shape='NCL')
# with_cls_token must be True when output_cls_token == True
VisionTransformer(with_cls_token=False, output_cls_token=True)

# Test img_size isinstance tuple
imgs = torch.randn(1, 3, 224, 224)
@@ -88,6 +88,11 @@ def test_vit_backbone():
feat = model(imgs)
assert feat[-1].shape == (1, 768, 7, 14)

# Test irregular input image
imgs = torch.randn(1, 3, 234, 345)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 15, 22)

# Test with_cp=True
model = VisionTransformer(with_cp=True)
imgs = torch.randn(1, 3, 224, 224)
@@ -100,12 +105,6 @@ def test_vit_backbone():
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)

# Test out_shape == 'NLC'
model = VisionTransformer(out_shape='NLC')
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 196, 768)

# Test final norm
model = VisionTransformer(final_norm=True)
imgs = torch.randn(1, 3, 224, 224)
@@ -117,3 +116,10 @@ def test_vit_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[-1].shape == (1, 768, 14, 14)

# Test output_cls_token
model = VisionTransformer(with_cls_token=True, output_cls_token=True)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert feat[0][0].shape == (1, 768, 14, 14)
assert feat[0][1].shape == (1, 768)