Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix swin backbone absolute pos_embed #8127

Merged
merged 5 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mmdet/models/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,8 @@ def __init__(self,
if self.use_abs_pos_embed:
patch_row = pretrain_img_size[0] // patch_size
patch_col = pretrain_img_size[1] // patch_size
num_patches = patch_row * patch_col
self.absolute_pos_embed = nn.Parameter(
torch.zeros((1, num_patches, embed_dims)))
torch.zeros((1, embed_dims, patch_row, patch_col)))

self.drop_after_pos = nn.Dropout(p=drop_rate)

Expand Down Expand Up @@ -746,7 +745,17 @@ def forward(self, x):
x, hw_shape = self.patch_embed(x)

if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
h, w = self.absolute_pos_embed.shape[1:3]
if hw_shape[0] != h or hw_shape[1] != w:
absolute_pos_embed = F.interpolate(
self.absolute_pos_embed,
size=hw_shape,
mode='bicubic',
align_corners=False).flatten(2).transpose(1, 2)
else:
absolute_pos_embed = self.absolute_pos_embed.flatten(
2).transpose(1, 2)
x = x + absolute_pos_embed
x = self.drop_after_pos(x)

outs = []
Expand Down
5 changes: 5 additions & 0 deletions tests/test_models/test_backbones/test_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def test_swin_transformer():
model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)
model.init_weights()
model(temp)
# Test different inputs when use absolute position embedding
temp = torch.randn((1, 3, 112, 112))
model(temp)
temp = torch.randn((1, 3, 256, 256))
model(temp)

# Test patch norm
model = SwinTransformer(patch_norm=False)
Expand Down