[Fix] fix patch_embed and pos_embed mismatch error #685
Changes from 4 commits
@@ -120,6 +120,7 @@ class VisionTransformer(BaseModule):
        drop_path_rate (float): stochastic depth rate. Default 0.0
        with_cls_token (bool): If concatenating class token into image tokens
            as transformer input. Default: True.
+        output_cls_token (bool): Whether output the cls_token. Default: False.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        act_cfg (dict): The activation config for FFNs.
@@ -128,8 +129,6 @@ class VisionTransformer(BaseModule):
            Default: False.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Default: False.
-        out_shape (str): Select the output format of feature information.
-            Default: NCHW.
        interpolate_mode (str): Select the interpolate mode for position
            embeding vector resize. Default: bicubic.
        num_fcs (int): The number of fully-connected layers for FFNs.
@@ -160,11 +159,11 @@ def __init__(self,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 with_cls_token=True,
+                 output_cls_token=False,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='GELU'),
                 patch_norm=False,
                 final_norm=False,
-                 out_shape='NCHW',
                 interpolate_mode='bicubic',
                 num_fcs=2,
                 norm_eval=False,
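Not part of the diff: a minimal usage sketch of the new `output_cls_token` flag. The import path and the defaults assumed below (`embed_dims=768`, a 14x14 patch grid for a 224 input with 16x16 patches) come from the surrounding mmseg code, not from this PR.

```python
# Hedged sketch only: shapes and defaults are assumptions, not asserted by the diff.
import torch
from mmseg.models.backbones import VisionTransformer  # assumed import path

vit = VisionTransformer(
    img_size=224,
    patch_size=16,
    with_cls_token=True,      # must stay True when output_cls_token=True
    output_cls_token=True)

x = torch.randn(1, 3, 224, 224)
for out in vit(x):
    feat, cls_token = out     # each output is now [NCHW feature map, cls token]
    print(feat.shape)         # expected for the final stage: torch.Size([1, 768, 14, 14])
    print(cls_token.shape)    # expected: torch.Size([1, 768])
```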
@@ -185,8 +184,9 @@ def __init__(self,

        assert pretrain_style in ['timm', 'mmcls']

-        assert out_shape in ['NLC',
-                             'NCHW'], 'output shape must be "NLC" or "NCHW".'
+        if output_cls_token:
+            assert with_cls_token is True, f'with_cls_token must be True if' \
+                f'set output_cls_token to True, but got {with_cls_token}'

        if isinstance(pretrained, str) or pretrained is None:
            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
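The new guard rejects exactly one combination of flags; a quick sketch of what it catches (all other constructor arguments are assumed to keep their defaults):

```python
# Hedged sketch: the assertion added above fires when output_cls_token is
# requested without keeping the cls token in the token sequence.
from mmseg.models.backbones import VisionTransformer  # assumed import path

try:
    VisionTransformer(with_cls_token=False, output_cls_token=True)
except AssertionError as err:
    print(err)  # points at the with_cls_token / output_cls_token mismatch
```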
@@ -196,7 +196,6 @@ def __init__(self,

        self.img_size = img_size
        self.patch_size = patch_size
-        self.out_shape = out_shape
        self.interpolate_mode = interpolate_mode
        self.norm_eval = norm_eval
        self.with_cp = with_cp
@@ -218,6 +217,7 @@ def __init__(self,
            (img_size[1] // patch_size)

        self.with_cls_token = with_cls_token
+        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dims))
@@ -253,7 +253,6 @@ def __init__(self,
                    batch_first=True))

        self.final_norm = final_norm
-        self.out_shape = out_shape
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, embed_dims, postfix=1)
@@ -317,14 +316,13 @@ def init_weights(self):
                    constant_init(m.bias, 0)
                    constant_init(m.weight, 1.0)

-    def _pos_embeding(self, img, patched_img, pos_embed):
+    def _pos_embeding(self, downsampled_img_size, patched_img, pos_embed):
        """Positiong embeding method.

        Resize the pos_embed, if the input image size doesn't match
            the training size.
        Args:
-            img (torch.Tensor): The inference image tensor, the shape
-                must be [B, C, H, W].
+            downsampled_img_size (tuple): The downsampled image resolution.
            patched_img (torch.Tensor): The patched image, it should be
                shape of [B, L1, C].
            pos_embed (torch.Tensor): The pos_embed weighs, it should be
@@ -344,7 +342,7 @@ def _pos_embeding(self, img, patched_img, pos_embed):
                raise ValueError(
                    'Unexpected shape of pos_embed, got {}.'.format(
                        pos_embed.shape))
-            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
+            pos_embed = self.resize_pos_embed(pos_embed, downsampled_img_size,
                                              (pos_h, pos_w), self.patch_size,
                                              self.interpolate_mode)
        return self.drop_after_pos(patched_img + pos_embed)
@@ -371,7 +369,7 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = F.interpolate(
            pos_embed_weight,
-            size=[input_h // patch_size, input_w // patch_size],
+            size=[input_h, input_w],
            align_corners=False,
            mode=mode)
        cls_token_weight = cls_token_weight.unsqueeze(1)
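This one-line change is the core of the fix: the caller now passes the patch grid already produced by `patch_embed`, so dividing by `patch_size` a second time would shrink the target by another factor of 16 and cause the pos_embed mismatch. A standalone sketch of the intended resize, in plain PyTorch with illustrative shapes (not the mmseg implementation itself):

```python
# Hedged sketch of the resize semantics after the fix: interpolate the patch
# part of pos_embed to the new patch grid, keep the cls-token row unchanged.
import torch
import torch.nn.functional as F

embed_dims = 768
pos_h = pos_w = 14          # grid the pretrained pos_embed was trained on (224 / 16)
new_h, new_w = 32, 32       # grid reported by patch_embed for a 512x512 input

pos_embed = torch.zeros(1, pos_h * pos_w + 1, embed_dims)   # [cls | patch positions]
cls_token_weight = pos_embed[:, 0:1]
pos_embed_weight = pos_embed[:, 1:].reshape(
    1, pos_h, pos_w, embed_dims).permute(0, 3, 1, 2)

# After the fix, the target size is the patch grid itself. Dividing by
# patch_size again (the old behaviour) would yield a 2x2 grid here and a
# shape mismatch when the result is added to the patch tokens.
pos_embed_weight = F.interpolate(
    pos_embed_weight, size=(new_h, new_w), mode='bicubic', align_corners=False)

pos_embed_weight = pos_embed_weight.permute(0, 2, 3, 1).reshape(
    1, new_h * new_w, embed_dims)
resized = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
print(resized.shape)        # torch.Size([1, 1025, 768]) -> 32*32 patch tokens + cls
```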
@@ -382,12 +380,12 @@ def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
    def forward(self, inputs):
        B = inputs.shape[0]

-        x = self.patch_embed(inputs)
+        x, H, W = self.patch_embed(
+            inputs), self.patch_embed.DH, self.patch_embed.DW
Review comment on this change: "We may use …"
        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
-        x = self._pos_embeding(inputs, x, self.pos_embed)
+        x = self._pos_embeding((H, W), x, self.pos_embed)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
@@ -405,11 +403,10 @@ def forward(self, inputs):
                    out = x[:, 1:]
                else:
                    out = x
-                if self.out_shape == 'NCHW':
-                    B, _, C = out.shape
-                    out = out.reshape(B, inputs.shape[2] // self.patch_size,
-                                      inputs.shape[3] // self.patch_size,
-                                      C).permute(0, 3, 1, 2)
+                B, _, C = out.shape
+                out = out.reshape(B, H, W, C).permute(0, 3, 1, 2)
+                if self.output_cls_token:
+                    out = [out, x[:, 0]]
                outs.append(out)

        return tuple(outs)
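With `out_shape` removed, the NLC-to-NCHW reshape above always uses the `H` and `W` recorded by `patch_embed` instead of `inputs.shape // patch_size`, so the grid matches the actual token count even when the input size is not an exact multiple of the patch size (the mismatch this PR fixes). A minimal sketch of that reshape with illustrative shapes, independent of mmseg:

```python
# Illustrative only: the patch embedding produced a 32x24 grid for an input
# whose height and width are not exact multiples of the patch size.
import torch

B, C = 2, 768
H, W = 32, 24                       # grid actually produced by patch_embed
tokens = torch.randn(B, H * W, C)   # [B, L, C] transformer output, cls token removed

feature_map = tokens.reshape(B, H, W, C).permute(0, 3, 1, 2)
print(feature_map.shape)            # torch.Size([2, 768, 32, 24])

# The old code recomputed the grid as inputs.shape[2] // patch_size and
# inputs.shape[3] // patch_size, which can disagree with L = H * W for such
# inputs and makes the reshape fail.
```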
Review comment: It is better to add this description to the docstring.