[Fix] Add setr & vit msg. (#635)
* [Fix] Add setr & vit msg.

* Fix init bug

* Modify init_cfg arg

* Add conv_seg init
clownrat6 authored Jun 24, 2021
1 parent ec91893 commit 98067be
Showing 4 changed files with 21 additions and 12 deletions.
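To make the effect of the change concrete, here is a minimal usage sketch (not part of the commit), assuming mmcv's `BaseModule` initialization machinery; the channel sizes are illustrative:

```python
# Minimal sketch (not part of this commit): how the new declarative
# initialization behaves, assuming mmcv's BaseModule machinery.
from mmseg.models.decode_heads import SETRUPHead

head = SETRUPHead(in_channels=32, channels=16, num_classes=19)

# With the defaults added in this commit, init_weights() constant-inits
# every LayerNorm (weight=1.0, bias=0) and normal-inits the conv_seg
# classifier (std=0.01), replacing the removed hand-written
# init_weights() shown in the diff below.
head.init_weights()
```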
2 changes: 2 additions & 0 deletions README.md
@@ -63,6 +63,7 @@ Supported backbones:
 - [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
 - [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
 - [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
+- [x] [Vision Transformer (ICLR'2021)]

 Supported methods:

@@ -89,6 +90,7 @@ Supported methods:
 - [x] [DNLNet (ECCV'2020)](configs/dnlnet)
 - [x] [PointRend (CVPR'2020)](configs/point_rend)
 - [x] [CGNet (TIP'2020)](configs/cgnet)
+- [x] [SETR (CVPR'2021)](configs/setr)

 ## Installation
2 changes: 2 additions & 0 deletions README_zh-CN.md
@@ -62,6 +62,7 @@ MMSegmentation is an open source semantic segmentation toolbox based on PyTorch. It is part of the O
 - [x] [ResNeSt (ArXiv'2020)](configs/resnest/README.md)
 - [x] [MobileNetV2 (CVPR'2018)](configs/mobilenet_v2/README.md)
 - [x] [MobileNetV3 (ICCV'2019)](configs/mobilenet_v3/README.md)
+- [x] [Vision Transformer (ICLR'2021)]

 Supported methods:

@@ -87,6 +88,7 @@ Supported methods:
 - [x] [DNLNet (ECCV'2020)](configs/dnlnet)
 - [x] [PointRend (CVPR'2020)](configs/point_rend)
 - [x] [CGNet (TIP'2020)](configs/cgnet)
+- [x] [SETR (CVPR'2021)](configs/setr)

 ## Installation
22 changes: 13 additions & 9 deletions mmseg/models/decode_heads/setr_up_head.py
@@ -1,5 +1,5 @@
 import torch.nn as nn
-from mmcv.cnn import ConvModule, build_norm_layer, constant_init
+from mmcv.cnn import ConvModule, build_norm_layer

 from ..builder import HEADS
 from .decode_head import BaseDecodeHead
@@ -18,18 +18,28 @@ class SETRUPHead(BaseDecodeHead):
         up_scale (int): The scale factor of interpolate. Default:4.
         kernel_size (int): The kernel size of convolution when decoding
             feature information from backbone. Default: 3.
+        init_cfg (dict | list[dict] | None): Initialization config dict.
+            Default: dict(
+                type='Constant', val=1.0, bias=0, layer='LayerNorm').
     """

     def __init__(self,
                  norm_layer=dict(type='LN', eps=1e-6, requires_grad=True),
                  num_convs=1,
                  up_scale=4,
                  kernel_size=3,
+                 init_cfg=[
+                     dict(type='Constant', val=1.0, bias=0, layer='LayerNorm'),
+                     dict(
+                         type='Normal',
+                         std=0.01,
+                         override=dict(name='conv_seg'))
+                 ],
                  **kwargs):

         assert kernel_size in [1, 3], 'kernel_size must be 1 or 3.'

-        super(SETRUPHead, self).__init__(**kwargs)
+        super(SETRUPHead, self).__init__(init_cfg=init_cfg, **kwargs)

         assert isinstance(self.in_channels, int)

@@ -38,7 +48,7 @@ def __init__(self,
         self.up_convs = nn.ModuleList()
         in_channels = self.in_channels
         out_channels = self.channels
-        for i in range(num_convs):
+        for _ in range(num_convs):
             self.up_convs.append(
                 nn.Sequential(
                     ConvModule(
@@ -55,12 +65,6 @@ def __init__(self,
                         align_corners=self.align_corners)))
             in_channels = out_channels

-    def init_weights(self):
-        for m in self.modules():
-            if isinstance(m, nn.LayerNorm):
-                constant_init(m.bias, 0)
-                constant_init(m.weight, 1.0)
-
     def forward(self, x):
         x = self._transform_inputs(x)
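For reference, a sketch of the imperative equivalent of the new `init_cfg` list, using `mmcv.cnn` helpers (`manual_init` is a hypothetical name, not part of this commit):

```python
import torch.nn as nn
from mmcv.cnn import constant_init, normal_init

def manual_init(head):
    """Hypothetical helper: what the default init_cfg list expresses."""
    for m in head.modules():
        if isinstance(m, nn.LayerNorm):
            # dict(type='Constant', val=1.0, bias=0, layer='LayerNorm')
            constant_init(m, val=1.0, bias=0)
    # dict(type='Normal', std=0.01, override=dict(name='conv_seg'))
    normal_init(head.conv_seg, std=0.01)
```

Note that the removed `init_weights` passed `m.bias` and `m.weight` (tensors) to `constant_init`, which expects a module; that mismatch appears to be the "Fix init bug" referenced in the commit message.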
7 changes: 4 additions & 3 deletions tests/test_models/test_heads/test_setr_up_head.py
@@ -16,13 +16,14 @@ def test_setr_up_head(capsys):
         # as embed_dim.
         SETRUPHead(in_channels=(32, 32), channels=16, num_classes=19)

-    # test init_weights of head
+    # test init_cfg of head
     head = SETRUPHead(
         in_channels=32,
         channels=16,
         norm_cfg=dict(type='SyncBN'),
-        num_classes=19)
-    head.init_weights()
+        num_classes=19,
+        init_cfg=dict(type='Kaiming'))
+    super(SETRUPHead, head).init_weights()

     # test inference of Naive head
     # the auxiliary head of Naive head is same as Naive head
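Since `SETRUPHead` no longer defines its own `init_weights`, the test calls the base class's `init_weights()` directly via `super(SETRUPHead, head)`, presumably to exercise the `init_cfg`-driven path inherited from mmcv's `BaseModule`.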
