[Fix] fix patch_embed and pos_embed mismatch error #685
@@ -120,6 +120,7 @@ class VisionTransformer(BaseModule):
         drop_path_rate (float): stochastic depth rate. Default 0.0
         with_cls_token (bool): If concatenating class token into image tokens
             as transformer input. Default: True.
+        output_cls_token (bool): Whether output the cls_token. Default: False.
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN')
         act_cfg (dict): The activation config for FFNs.
@@ -128,8 +129,6 @@ class VisionTransformer(BaseModule):
             Default: False.
         final_norm (bool): Whether to add a additional layer to normalize
             final feature map. Default: False.
-        out_shape (str): Select the output format of feature information.
-            Default: NCHW.
         interpolate_mode (str): Select the interpolate mode for position
             embeding vector resize. Default: bicubic.
         num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +159,11 @@ def __init__(self,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
+                output_cls_token=False,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
-                out_shape='NCHW',
                 interpolate_mode='bicubic',
                 num_fcs=2,
                 norm_eval=False,
@@ -185,8 +184,9 @@ def __init__(self,

        assert pretrain_style in ['timm', 'mmcls']

-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+        if output_cls_token:
+            assert with_cls_token is True, f'with_cls_token must be True if' \
+                f'set output_cls_token to True, but got {with_cls_token}'

        if isinstance(pretrained, str) or pretrained is None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
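Taken together with the docstring and signature hunks above, output_cls_token is an opt-in switch that only makes sense while the class token is kept. A minimal, hedged sketch of how a user might enable it (the two flag names come from this diff; the other keys and the registry name 'VisionTransformer' are illustrative assumptions, not copied from a released config):

# Hypothetical backbone config; keys other than with_cls_token and
# output_cls_token are illustrative assumptions.
backbone = dict(
    type='VisionTransformer',
    img_size=(512, 512),
    patch_size=16,
    embed_dims=768,
    with_cls_token=True,      # must stay True,
    output_cls_token=True)    # otherwise the assert above is triggered

With this switch on, each backbone output (see the forward() hunks below) becomes a [feature_map, cls_token] pair instead of a plain NCHW tensor.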
@@ -196,7 +196,6 @@ def __init__(self,

        self.img_size = img_size
        self.patch_size = patch_size
-        self.out_shape = out_shape
        self.interpolate_mode = interpolate_mode
        self.norm_eval = norm_eval
        self.with_cp = with_cp
@@ -218,6 +217,7 @@ def __init__(self,
            (img_size[1] // patch_size)

        self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +253,6 @@ def __init__(self,
                    batch_first=True))

        self.final_norm = final_norm
-        self.out_shape = out_shape
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, embed_dims, postfix=1)
@@ -290,8 +289,9 @@ def init_weights(self):
                    pos_size = int(
                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
                    state_dict['pos_embed'] = self.resize_pos_embed(
-                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
-                        self.patch_size, self.interpolate_mode)
+                        state_dict['pos_embed'],
+                        (h // self.patch_size, w // self.patch_size),
+                        (pos_size, pos_size), self.interpolate_mode)

            self.load_state_dict(state_dict, False)
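This hunk is the heart of the fix: resize_pos_embed now receives the patch grid of the fine-tuning input rather than its raw pixel size. A small, self-contained shape calculation makes the mismatch concrete (the 224/512/16 numbers are illustrative assumptions, not taken from this PR):

import math

# A checkpoint pretrained at 224x224 with patch_size=16 stores a pos_embed of
# length 14 * 14 + 1 (the extra slot is the cls_token).
pretrained_len = 14 * 14 + 1                   # 197
h, w, patch_size = 512, 512, 16                # assumed fine-tuning crop

pos_size = int(math.sqrt(pretrained_len - 1))  # 14, side of the stored grid
new_grid = (h // patch_size, w // patch_size)  # (32, 32), patch_embed's grid

print(pos_size, new_grid)                      # 14 (32, 32)

Passing (h // self.patch_size, w // self.patch_size) up front keeps the interpolation target identical to what patch_embed actually produces, which is what removes the patch_embed and pos_embed mismatch error.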
@@ -317,16 +317,15 @@ def init_weights(self):
                constant_init(m.bias, 0)
                constant_init(m.weight, 1.0)

-    def _pos_embeding(self, img, patched_img, pos_embed):
+    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
        """Positiong embeding method.

        Resize the pos_embed, if the input image size doesn't match
            the training size.
        Args:
-            img (torch.Tensor): The inference image tensor, the shape
-                must be [B, C, H, W].
            patched_img (torch.Tensor): The patched image, it should be
                shape of [B, L1, C].
+            hw_shape (tuple): The downsampled image resolution.

Review comment on the hw_shape docstring: Maybe I am not sure. @xvjiarui

            pos_embed (torch.Tensor): The pos_embed weighs, it should be
                shape of [B, L2, c].
        Return:
@@ -344,21 +343,21 @@ def _pos_embeding(self, img, patched_img, pos_embed):
                raise ValueError(
                    'Unexpected shape of pos_embed, got {}.'.format(
                        pos_embed.shape))
-            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
-                                              (pos_h, pos_w), self.patch_size,
+            pos_embed = self.resize_pos_embed(pos_embed, hw_shape,
+                                              (pos_h, pos_w),
                                              self.interpolate_mode)
        return self.drop_after_pos(patched_img + pos_embed)

    @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
        """Resize pos_embed weights.

        Resize pos_embed using bicubic interpolate method.
        Args:
            pos_embed (torch.Tensor): pos_embed weights.
            input_shpae (tuple): Tuple for (input_h, intput_w).

Review comment on the input_shpae docstring: Maybe I am not sure. @xvjiarui

            pos_shape (tuple): Tuple for (pos_h, pos_w).
-            patch_size (int): Patch size.
            mode (str): Algorithm used for upsampling.

Review comment on the mode docstring: It would be better to describe the …

        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
        """
@@ -371,7 +370,7 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(
            pos_embed_weight,
-            size=[input_h // patch_size, input_w // patch_size],
+            size=[input_h, input_w],
            align_corners=False,
            mode=mode)
        cls_token_weight = cls_token_weight.unsqueeze(1)
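Read together, the two hunks above change resize_pos_embed so that it interpolates straight to the patch grid passed in by the caller. The following standalone sketch is reconstructed from the visible diff lines (keeping the original input_shpae spelling); it is an approximation for illustration, not a verbatim copy of the method:

import torch
import torch.nn.functional as F


def resize_pos_embed_sketch(pos_embed, input_shpae, pos_shape, mode='bicubic'):
    # pos_embed: [1, pos_h * pos_w + 1, C]; input_shpae is the downsampled
    # (patch grid) resolution, no longer the raw image size.
    pos_h, pos_w = pos_shape
    input_h, input_w = input_shpae
    cls_token_weight = pos_embed[:, 0]                    # [1, C]
    pos_embed_weight = pos_embed[:, 1:]                   # [1, pos_h*pos_w, C]
    pos_embed_weight = pos_embed_weight.reshape(
        1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
    pos_embed_weight = F.interpolate(
        pos_embed_weight, size=[input_h, input_w],
        align_corners=False, mode=mode)                   # no // patch_size
    cls_token_weight = cls_token_weight.unsqueeze(1)      # [1, 1, C]
    pos_embed_weight = pos_embed_weight.flatten(2).transpose(1, 2)
    return torch.cat((cls_token_weight, pos_embed_weight), dim=1)


# Quick shape check with assumed sizes: a 14x14 grid resized to 32x32.
pos_embed = torch.zeros(1, 14 * 14 + 1, 768)
print(resize_pos_embed_sketch(pos_embed, (32, 32), (14, 14)).shape)
# torch.Size([1, 1025, 768])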
@@ -382,12 +381,12 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
    def forward(self, inputs):
        B = inputs.shape[0]

-        x = self.patch_embed(inputs)
-
+        x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH,
+                                                 self.patch_embed.DW)
        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
-        x = self._pos_embeding(inputs, x, self.pos_embed)
+        x = self._pos_embeding(x, hw_shape, self.pos_embed)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
@@ -405,11 +404,11 @@ def forward(self, inputs):
                    out = x[:, 1:]
                else:
                    out = x
-                if self.out_shape == 'NCHW':
-                    B, _, C = out.shape
-                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                      inputs.shape[3] // self.patch_size,
-                                      C).permute(0, 3, 1, 2)
+                B, _, C = out.shape
+                out = out.reshape(B, hw_shape[0], hw_shape[1],
+                                  C).permute(0, 3, 1, 2)
+                if self.output_cls_token:
+                    out = [out, x[:, 0]]
                outs.append(out)

        return tuple(outs)
Review comment: It is better to add this description to Docstring.
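Finally, a self-contained illustration of the new output path in forward(): hw_shape (reported by patch_embed in the hunks above) drives the NCHW reshape, and output_cls_token turns each stage output into a [feature_map, cls_token] pair. The concrete sizes below are assumptions for illustration:

import torch

B, C = 2, 768
hw_shape = (32, 32)                                  # e.g. 512 // 16 per side
x = torch.randn(B, hw_shape[0] * hw_shape[1] + 1, C)  # tokens incl. cls_token

out = x[:, 1:]                                       # drop the cls_token
out = out.reshape(B, hw_shape[0], hw_shape[1], C).permute(0, 3, 1, 2)
print(out.shape)                                     # torch.Size([2, 768, 32, 32])

# With output_cls_token=True the stage output becomes a pair instead of a
# plain NCHW tensor:
out = [out, x[:, 0]]                                 # [feature map, cls_token of shape [B, C]]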