[Refactor] Using mmcv transformer bricks to refactor vit. #571
Conversation
Codecov Report

```diff
@@            Coverage Diff             @@
##           master     #571      +/-   ##
==========================================
- Coverage   85.95%   85.45%   -0.50%
==========================================
  Files         101      101
  Lines        5234     5220      -14
  Branches      828      840      +12
==========================================
- Hits         4499     4461      -38
- Misses        561      586      +25
+ Partials      174      173       -1
```

Flags with carried forward coverage won't be shown. Continue to review the full report at Codecov.
```diff
@@ -0,0 +1,53 @@
+import logging
```
Is this file necessary?
mmseg/models/utils/helpers.py
Outdated
```python
import collections.abc
from itertools import repeat


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
```
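For reference, the `_ntuple` helper above is self-contained, so its behavior is easy to check in isolation: scalars are repeated `n` times, while iterables pass through unchanged. A quick sketch:

```python
import collections.abc
from itertools import repeat


def _ntuple(n):

    def parse(x):
        # iterables pass through unchanged; scalars are repeated n times
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)

print(to_2tuple(7))       # (7, 7)
print(to_2tuple((3, 5)))  # (3, 5)
```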
This file is not necessary. Use `from torch.nn.modules.utils import _pair as to_2tuple` instead.
mmseg/models/backbones/vit.py
Outdated
```python
# We only implement the 'jax_impl' initialization implemented at
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
for n, m in self.named_modules():
    if isinstance(m, Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            if 'mlp' in n:
                normal_init(m.bias, std=1e-6)
            else:
                constant_init(m.bias, 0)
    elif isinstance(m, Conv2d):
        kaiming_init(m.weight, mode='fan_in')
        if m.bias is not None:
            constant_init(m.bias, 0)
    elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
        constant_init(m.bias, 0)
        constant_init(m.weight, 1.0)
else:
    raise TypeError('pretrained must be a str or None')
# Modified from ClassyVision
nn.init.normal_(self.pos_embed, std=0.02)
```
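As context for the `trunc_normal_` calls in the snippet above: a truncated normal initializer draws from a Gaussian but discards samples beyond a cutoff. A minimal pure-Python sketch of the idea (my own illustration via rejection sampling, not timm's implementation, which uses an inverse-CDF method with absolute bounds `a` and `b`), truncating at two standard deviations:

```python
import random


def trunc_normal(mean=0.0, std=0.02, n=1000):
    """Rejection-sample n values from N(mean, std), redrawing any
    sample that falls more than two standard deviations from the mean."""
    out = []
    while len(out) < n:
        x = random.gauss(mean, std)
        if abs(x - mean) <= 2 * std:
            out.append(x)
    return out


samples = trunc_normal()
# with std=0.02, every sample lies within [-0.04, 0.04]
```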
Why remove the initialization?
1. Use timm-style init_weights; 2. Remove to_xtuple and trunc_normal_;
mmseg/models/backbones/vit.py
Outdated
```diff
@@ -330,10 +299,17 @@ def init_weights(self, pretrained=None):
             else:
                 state_dict = checkpoint

+            if 'rwightman/pytorch-image-models' in pretrained:
```
If the user has downloaded the weights from timm and wants to initialize the model from a local path, this condition does not hold.
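One way to see the reviewer's point: a checkpoint loaded from a local file contains no repository substring to match, so sniffing the string is fragile. A small sketch (a hypothetical helper, not code from this PR) that distinguishes a URL from a local path with the standard library instead:

```python
from urllib.parse import urlparse


def is_url(path):
    """Return True if the string parses as an http(s) URL."""
    return urlparse(path).scheme in ('http', 'https')


print(is_url('https://example.com/vit_timm.pth'))   # True
print(is_url('/home/user/checkpoints/vit_timm.pth'))  # False
```

Even with such a check, the weight *style* (timm vs. mmcls) still cannot be inferred from the location, which is why an explicit argument is the safer design.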
.github/workflows/build.yml
Outdated
```diff
@@ -37,19 +37,26 @@ jobs:
         include:
           - torch: 1.3.0+cpu
             torchvision: 0.4.1+cpu
+            torch_version: 1.3.0
```
Reset this file.
```python
with_cp (bool): Use checkpoint or not. Using checkpoint will save
    some memory while slowing down the training speed. Default: False.
pretrain_style (str): Choose to use timm or mmcls pretrain weights.
    Default: timm.
```
We should explain which options are supported, and add an assert.
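A sketch of what that could look like (the helper name and message wording are my assumption; only the `pretrain_style` argument and its `timm`/`mmcls` options come from this PR):

```python
def check_pretrain_style(pretrain_style):
    """Validate pretrain_style explicitly so unsupported values
    fail fast with a clear error message."""
    supported = ('timm', 'mmcls')
    assert pretrain_style in supported, (
        f'pretrain_style must be one of {supported}, '
        f'got {pretrain_style!r}')
    return pretrain_style


check_pretrain_style('timm')  # passes silently
```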
…#571)
* [Refactor] Using mmcv bricks to refactor vit
* Follow the vit code structure from mmclassification
* Add MMCV install into CI system.
* Add to 'Install MMCV' CI item
* Add 'Install MMCV_CPU' and 'Install MMCV_GPU' CI items
* Fix & Add: 1. Fix low code coverage of vit.py; 2. Remove HybridEmbed; 3. Fix doc string of VisionTransformer;
* Add helpers unit test.
* Add converter to convert vit pretrain weights from timm style to mmcls style.
* Clean some redundant code and refactor init: 1. Use timm style init_weights; 2. Remove to_xtuple and trunc_normal_;
* Add comments for VisionTransformer.init_weights()
* Add arg: pretrain_style to choose timm or mmcls vit pretrain weights.
The foundation of this PR:
mmcv: open-mmlab/mmcv#978 (merged)
mmclassification: open-mmlab/mmpretrain#295