[WIP] Add Swin Transformer #511
Conversation
mmcv_custom/checkpoint.py
Outdated
@@ -0,0 +1,496 @@
# Copyright (c) Open-MMLab. All rights reserved.
We may remove this file and directory.
init_cfg=None,
)
Suggested change:
-    init_cfg=None,
-    )
+    init_cfg=None)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
dilation = to_2tuple(dilation)
self.sampler = nn.Unfold(kernel_size, dilation, padding, stride)
The padding may need to be calculated by users.
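For illustration, a rough sketch of one way to do that (the helper name `pad_to_stride` and its exact signature are assumptions, not the PR's code): pad H and W up to a multiple of the merging stride inside forward, instead of relying on a fixed `padding` argument.

```python
import torch.nn.functional as F

def pad_to_stride(x, hw_shape, stride=2):
    """Pad a (B, H*W, C) token tensor so H and W divide `stride` evenly."""
    H, W = hw_shape
    B, L, C = x.shape
    assert L == H * W, 'input feature has wrong size'
    x = x.view(B, H, W, C).permute(0, 3, 1, 2)  # (B, C, H, W)
    pad_h = (stride - H % stride) % stride
    pad_w = (stride - W % stride) % stride
    # F.pad takes (W_left, W_right, H_top, H_bottom) for the last two dims
    x = F.pad(x, (0, pad_w, 0, pad_h))
    return x, (H + pad_h, W + pad_w)
```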
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
Missing docstring.
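As a sketch only (the argument list below is an assumption based on the surrounding discussion, not the final signature), the docstring could look something like:

```python
from mmcv.runner import BaseModule

class ShiftWindowMSA(BaseModule):
    """Shifted-window multi-head self-attention module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of parallel attention heads.
        window_size (int): Size of the local attention window.
        shift_size (int): Offset of the shifted window partition.
            Defaults to 0 (regular, non-shifted windows).
        init_cfg (dict, optional): Initialization config. Defaults to None.
    """
```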
class ShiftWindowMSA(BaseModule):

    def __init__(self,
                 input_resolution,
In our setting, the input size may change during inference or training, so we shouldn't determine the input size when initializing the module.
    return windows


class SwinBlock(BaseModule):
Missing docstring.
class SwinBlock(BaseModule):

    def __init__(self,
                 input_resolution,
Similarly, the input size should be unknown at initialization time.
    def forward(self, query):
        for block in self.blocks:
            query = block(query)

        if self.downsample:
            query = self.downsample(query)
        return query
I suggest packing (H, W) into `hw_shape`, and forwarding it as well.
Actually, H and W are wrapped in the attributes of PatchMerging: `H, W = self.output_resolution`.
Codecov Report
@@            Coverage Diff             @@
##           master     #511      +/-   ##
==========================================
- Coverage   85.77%   85.14%   -0.63%
==========================================
  Files         103      105       +2
  Lines        5307     5668     +361
  Branches      857      923      +66
==========================================
+ Hits         4552     4826     +274
- Misses        583      663      +80
- Partials      172      179       +7
mmseg/models/backbones/swin.py
Outdated
class ShiftWindowMSA(BaseModule):

    def __init__(self,
                 input_resolution,
We shouldn't have an `input_resolution` arg.
Reference: https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/87e6f90577435c94f3e92c7db1d36edc234d91f6/mmseg/models/backbones/swin_transformer.py#L156
mmseg/models/utils/embed.py
Outdated
@@ -0,0 +1,91 @@
import torch
Modified from xxx.
mmseg/models/backbones/swin.py
Outdated
        else:
            self.downsample = None

    def forward(self, x, H, W):
We may pack `H, W` into `hw_shape` as a tuple.
Reference https://github.com/facebookresearch/SlowFast/blob/2090f2918ac1ce890fdacd8fda2e590a46d5c734/slowfast/models/video_model_builder.py#L1002
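A minimal sketch of that convention (the helper name is hypothetical): callers pass a single `(H, W)` tuple alongside the flattened tokens rather than two separate arguments.

```python
def tokens_to_map(x, hw_shape):
    """Reshape flattened tokens (B, H*W, C) back into a (B, C, H, W) map."""
    H, W = hw_shape
    B, L, C = x.shape
    assert L == H * W, 'input feature has wrong size'
    return x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
```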
We also need a docstring for this.
mmseg/models/backbones/swin.py
Outdated
        if self.downsample:
            stage_out = x
            x = self.downsample(x, H, W)
            DH, DW = (H + 1) // 2, (W + 1) // 2
            return stage_out, H, W, x, DH, DW
        else:
            return x, H, W, x, H, W
The output should be `x, hw_shape`. We don't need the previous `hw_shape`.
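A rough sketch of a stage forward along these lines (class and attribute names are assumptions, not the merged code): each stage returns only the current tokens and their `hw_shape`, with the downsampled shape replacing the old one when patch merging is applied.

```python
import torch.nn as nn

class SwinStage(nn.Module):
    """Sketch of one Swin stage; names are illustrative."""

    def __init__(self, blocks, downsample=None):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.downsample = downsample

    def forward(self, x, hw_shape):
        for block in self.blocks:
            x = block(x, hw_shape)
        if self.downsample is not None:
            # downsample is expected to return the merged tokens and the
            # new (H, W); the previous hw_shape is simply discarded.
            x, hw_shape = self.downsample(x, hw_shape)
        return x, hw_shape
```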
mmseg/models/backbones/swin.py
Outdated
stride=None,
padding=0,
dilation=1,
Remove these args. Make stride=kernel_size.
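A sketch of the simplified signature being asked for here (illustrative, not the merged implementation): drop `padding`/`dilation` and let `stride` default to `kernel_size`, keeping patch merging a plain non-overlapping unfold.

```python
import torch.nn as nn

class PatchMerging(nn.Module):
    """Sketch of a reduced PatchMerging signature."""

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=None):
        super().__init__()
        stride = stride if stride is not None else kernel_size
        self.sampler = nn.Unfold(kernel_size=kernel_size, stride=stride)
        sample_dim = in_channels * kernel_size ** 2
        self.norm = nn.LayerNorm(sample_dim)
        self.reduction = nn.Linear(sample_dim, out_channels, bias=False)
```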
2. Correct weight convert function; 3. Fix the pad of Patch Merging;
mmseg/models/backbones/swin.py
Outdated
def __init__(self,
             in_channels,
             out_channels,
             kernel_size=2,
We may remove this.
mmseg/models/backbones/swin.py
Outdated
mlp_ratio=4,
depths=(2, 2, 6, 2),
num_heads=(3, 6, 12, 24),
strides=(None, None, None, None),
Why is the default None?
mmseg/models/backbones/swin.py
Outdated
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
depth (int): The number of blocks in this stage.
kernel_size (int): The kernel_size of patch merging.
We may also remove this. Use `stride` only.
mmseg/models/backbones/swin.py
Outdated
padding (int): The padding length of patch merging.
dilation (int): The dilation rate of kernel of patch merging.
Not needed.
mmseg/models/backbones/swin.py
Outdated
kernel_size,
stride,
padding,
dilation,
Keep stride only.
2. Fix some pad bug; 3. Modify config to adapt new swin implementation;
mmseg/models/backbones/swin.py
Outdated
paddings (tuple[int], optional): The patch merging or patch
    embedding padding length of each Swin Transformer stage.
    Default: (0, 0, 0, 0).
dilations (tuple[int], optional): The patch merging or patch
    embedding kernel dilation rate of each Swin Transformer stage.
    Default: (1, 1, 1, 1).
These are no longer needed.
mmseg/models/backbones/swin.py
Outdated
        if downsample:
            in_channels = in_channels * 2
Move this to the definition of `downsample`.
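A sketch of that restructuring (variable names are illustrative, assuming a `PatchMerging` downsample layer): the channel doubling happens where the layer is constructed instead of in a separate `in_channels = in_channels * 2` branch.

```python
if downsample:
    downsample_layer = PatchMerging(
        in_channels=in_channels, out_channels=2 * in_channels)
    in_channels = 2 * in_channels  # input width of the next stage
else:
    downsample_layer = None
```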
2. Modify pth url which keep meta attribute;
…into swin_transformer
Add ViT link
* add Swin Transformer
* add Swin Transformer
* fixed import
* Add some swin training settings.
* Fix some filename error.
* Fix attribute name: pretrain -> pretrained
* Upload mmcls implementation of swin transformer.
* Refactor Swin Transformer to follow mmcls style.
* Refactor init_weigths of swin_transformer.py
* Fix lint
* Match inference precision
* Add some comments
* Add swin_convert to load official style ckpt
* Remove arg: auto_pad
* 1. Complete comments for each block; 2. Correct weight convert function; 3. Fix the pad of Patch Merging;
* Clean function args.
* Fix vit unit test.
* 1. Add swin transformer unit tests; 2. Fix some pad bug; 3. Modify config to adapt new swin implementation;
* Modify config arg
* Update readme.md of swin
* Fix config arg error and Add some swin benchmark msg.
* Add MeM and ms test content for readme.md of swin transformer.
* Fix doc string of swin module
* 1. Register swin transformer to model list; 2. Modify pth url which keep meta attribute;
* Update swin.py
* Merge config settings.
* Modify config style.
* Update README.md Add ViT link
* Modify main readme.md

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Co-authored-by: sennnnn <201730271412@mail.scut.edu.cn>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
…ful). (open-mmlab#511)
* Removing `autocast` for `35-25% speedup`.
* iQuality
* Adding a slow test.
* Fixing mps noise generation.
* Raising error on wrong device, instead of just casting on behalf of user.
* Quality.
* fix merge

Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>
* resolve comments
* update changelog
* add test_batch
* add testing for `test_batch`
* fix mmcv version
* add test_batch
* add testing for `test_batch`
* enlarge test_input to pass unittest
* update names
* update changelog & faq
* update name
No description provided.