Skip to content

Commit b8f42c7

Browse files
authored
Add Semantic FPN (#94)
* Add Semantic FPN * remove HRFPN
1 parent 597b8a6 commit b8f42c7

14 files changed

+388
-37
lines changed

configs/_base_/models/fpn_r50.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# model settings
2+
norm_cfg = dict(type='SyncBN', requires_grad=True)
3+
model = dict(
4+
type='EncoderDecoder',
5+
pretrained='open-mmlab://resnet50_v1c',
6+
backbone=dict(
7+
type='ResNetV1c',
8+
depth=50,
9+
num_stages=4,
10+
out_indices=(0, 1, 2, 3),
11+
dilations=(1, 1, 1, 1),
12+
strides=(1, 2, 2, 2),
13+
norm_cfg=norm_cfg,
14+
norm_eval=False,
15+
style='pytorch',
16+
contract_dilation=True),
17+
neck=dict(
18+
type='FPN',
19+
in_channels=[256, 512, 1024, 2048],
20+
out_channels=256,
21+
num_outs=4),
22+
decode_head=dict(
23+
type='FPNHead',
24+
in_channels=[256, 256, 256, 256],
25+
in_index=[0, 1, 2, 3],
26+
feature_strides=[4, 8, 16, 32],
27+
channels=128,
28+
dropout_ratio=0.1,
29+
num_classes=19,
30+
norm_cfg=norm_cfg,
31+
align_corners=False,
32+
loss_decode=dict(
33+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
34+
# model training and testing settings
35+
train_cfg = dict()
36+
test_cfg = dict(mode='whole')

configs/sem_fpn/README.md

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Panoptic Feature Pyramid Networks
2+
3+
## Introduction
4+
```
5+
@article{Kirillov_2019,
6+
title={Panoptic Feature Pyramid Networks},
7+
ISBN={9781728132938},
8+
url={http://dx.doi.org/10.1109/CVPR.2019.00656},
9+
DOI={10.1109/cvpr.2019.00656},
10+
journal={2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
11+
publisher={IEEE},
12+
author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Dollar, Piotr},
13+
year={2019},
14+
month={Jun}
15+
}
16+
```
17+
18+
## Results and models
19+
20+
### Cityscapes
21+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
22+
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
23+
| FPN | R-50 | 512x1024 | 80000 | 2.8 | 13.54 | 74.52 | 76.08 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes_20200717_021437-94018a0d.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes-20200717_021437.log.json) |
24+
| FPN | R-101 | 512x1024 | 80000 | 3.9 | 10.29 | 75.80 | 77.40 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes_20200717_012416-c5800d4c.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes-20200717_012416.log.json) |
25+
26+
### ADE20K
27+
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
28+
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
29+
| FPN | R-50 | 512x512 | 160000 | 4.9 | 55.77 | 37.49 | 39.09 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k_20200718_131734-5b5a6ab9.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k-20200718_131734.log.json) |
30+
| FPN | R-101 | 512x512 | 160000 | 5.9 | 40.58 | 39.35 | 40.72 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k_20200718_131734-306b5004.pth) &#124; [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k-20200718_131734.log.json) |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
_base_ = './fpn_r50_512x1024_80k_cityscapes.py'
2+
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
_base_ = './fpn_r50_512x512_160k_ade20k.py'
2+
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_base_ = [
2+
'../_base_/models/fpn_r50.py', '../_base_/datasets/cityscapes.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
4+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
_base_ = [
2+
'../_base_/models/fpn_r50.py', '../_base_/datasets/ade20k.py',
3+
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
4+
]
5+
model = dict(decode_head=dict(num_classes=150))

mmseg/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
build_head, build_loss, build_segmentor)
44
from .decode_heads import * # noqa: F401,F403
55
from .losses import * # noqa: F401,F403
6+
from .necks import * # noqa: F401,F403
67
from .segmentors import * # noqa: F401,F403
78

89
__all__ = [

mmseg/models/decode_heads/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .da_head import DAHead
55
from .enc_head import EncHead
66
from .fcn_head import FCNHead
7+
from .fpn_head import FPNHead
78
from .gc_head import GCHead
89
from .nl_head import NLHead
910
from .ocr_head import OCRHead
@@ -16,5 +17,5 @@
1617
__all__ = [
1718
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
1819
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
19-
'EncHead', 'DepthwiseSeparableFCNHead'
20+
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
2021
]

mmseg/models/decode_heads/fpn_head.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
import torch.nn as nn
3+
from mmcv.cnn import ConvModule
4+
5+
from mmseg.ops import resize
6+
from ..builder import HEADS
7+
from .decode_head import BaseDecodeHead
8+
9+
10+
@HEADS.register_module()
11+
class FPNHead(BaseDecodeHead):
12+
"""Panoptic Feature Pyramid Networks.
13+
14+
This head is the implementation of `Semantic FPN
15+
<https://arxiv.org/abs/1901.02446>`_.
16+
17+
Args:
18+
feature_strides (tuple[int]): The strides for input feature maps.
19+
All strides are supposed to be powers of 2. The first
20+
one is of largest resolution.
21+
"""
22+
23+
def __init__(self, feature_strides, **kwargs):
24+
super(FPNHead, self).__init__(
25+
input_transform='multiple_select', **kwargs)
26+
assert len(feature_strides) == len(self.in_channels)
27+
assert min(feature_strides) == feature_strides[0]
28+
self.feature_strides = feature_strides
29+
30+
self.scale_heads = nn.ModuleList()
31+
for i in range(len(feature_strides)):
32+
head_length = max(
33+
1,
34+
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
35+
scale_head = []
36+
for k in range(head_length):
37+
scale_head.append(
38+
ConvModule(
39+
self.in_channels[i] if k == 0 else self.channels,
40+
self.channels,
41+
3,
42+
padding=1,
43+
conv_cfg=self.conv_cfg,
44+
norm_cfg=self.norm_cfg,
45+
act_cfg=self.act_cfg))
46+
if feature_strides[i] != feature_strides[0]:
47+
scale_head.append(
48+
nn.Upsample(
49+
scale_factor=2,
50+
mode='bilinear',
51+
align_corners=self.align_corners))
52+
self.scale_heads.append(nn.Sequential(*scale_head))
53+
54+
def forward(self, inputs):
55+
56+
x = self._transform_inputs(inputs)
57+
58+
output = self.scale_heads[0](x[0])
59+
for i in range(1, len(self.feature_strides)):
60+
# non inplace
61+
output = output + resize(
62+
self.scale_heads[i](x[i]),
63+
size=output.shape[2:],
64+
mode='bilinear',
65+
align_corners=self.align_corners)
66+
67+
output = self.cls_seg(output)
68+
return output

mmseg/models/necks/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .fpn import FPN
2+
3+
__all__ = ['FPN']

mmseg/models/necks/fpn.py

+212
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
from mmcv.cnn import ConvModule, xavier_init
4+
5+
from ..builder import NECKS
6+
7+
8+
@NECKS.register_module()
9+
class FPN(nn.Module):
10+
"""Feature Pyramid Network.
11+
12+
This is an implementation of Feature Pyramid Networks for Object
13+
Detection (https://arxiv.org/abs/1612.03144)
14+
15+
Args:
16+
in_channels (List[int]): Number of input channels per scale.
17+
out_channels (int): Number of output channels (used at each scale)
18+
num_outs (int): Number of output scales.
19+
start_level (int): Index of the start input backbone level used to
20+
build the feature pyramid. Default: 0.
21+
end_level (int): Index of the end input backbone level (exclusive) to
22+
build the feature pyramid. Default: -1, which means the last level.
23+
add_extra_convs (bool | str): If bool, it decides whether to add conv
24+
layers on top of the original feature maps. Default to False.
25+
If True, its actual mode is specified by `extra_convs_on_inputs`.
26+
If str, it specifies the source feature map of the extra convs.
27+
Only the following options are allowed
28+
29+
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
30+
- 'on_lateral': Last feature map after lateral convs.
31+
- 'on_output': The last output feature map after fpn convs.
32+
extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
33+
on the original feature from the backbone. If True,
34+
it is equivalent to `add_extra_convs='on_input'`. If False, it is
35+
equivalent to setting `add_extra_convs='on_output'`. Default to False.
36+
relu_before_extra_convs (bool): Whether to apply relu before the extra
37+
conv. Default: False.
38+
no_norm_on_lateral (bool): Whether to disable norm on lateral convs.
39+
Default: False.
40+
conv_cfg (dict): Config dict for convolution layer. Default: None.
41+
norm_cfg (dict): Config dict for normalization layer. Default: None.
42+
act_cfg (dict): Config dict for activation layer in ConvModule.
43+
Default: None.
44+
upsample_cfg (dict): Config dict for interpolate layer.
45+
Default: `dict(mode='nearest')`
46+
47+
Example:
48+
>>> import torch
49+
>>> in_channels = [2, 3, 5, 7]
50+
>>> scales = [340, 170, 84, 43]
51+
>>> inputs = [torch.rand(1, c, s, s)
52+
... for c, s in zip(in_channels, scales)]
53+
>>> self = FPN(in_channels, 11, len(in_channels)).eval()
54+
>>> outputs = self.forward(inputs)
55+
>>> for i in range(len(outputs)):
56+
... print(f'outputs[{i}].shape = {outputs[i].shape}')
57+
outputs[0].shape = torch.Size([1, 11, 340, 340])
58+
outputs[1].shape = torch.Size([1, 11, 170, 170])
59+
outputs[2].shape = torch.Size([1, 11, 84, 84])
60+
outputs[3].shape = torch.Size([1, 11, 43, 43])
61+
"""
62+
63+
def __init__(self,
64+
in_channels,
65+
out_channels,
66+
num_outs,
67+
start_level=0,
68+
end_level=-1,
69+
add_extra_convs=False,
70+
extra_convs_on_inputs=False,
71+
relu_before_extra_convs=False,
72+
no_norm_on_lateral=False,
73+
conv_cfg=None,
74+
norm_cfg=None,
75+
act_cfg=None,
76+
upsample_cfg=dict(mode='nearest')):
77+
super(FPN, self).__init__()
78+
assert isinstance(in_channels, list)
79+
self.in_channels = in_channels
80+
self.out_channels = out_channels
81+
self.num_ins = len(in_channels)
82+
self.num_outs = num_outs
83+
self.relu_before_extra_convs = relu_before_extra_convs
84+
self.no_norm_on_lateral = no_norm_on_lateral
85+
self.fp16_enabled = False
86+
self.upsample_cfg = upsample_cfg.copy()
87+
88+
if end_level == -1:
89+
self.backbone_end_level = self.num_ins
90+
assert num_outs >= self.num_ins - start_level
91+
else:
92+
# if end_level < inputs, no extra level is allowed
93+
self.backbone_end_level = end_level
94+
assert end_level <= len(in_channels)
95+
assert num_outs == end_level - start_level
96+
self.start_level = start_level
97+
self.end_level = end_level
98+
self.add_extra_convs = add_extra_convs
99+
assert isinstance(add_extra_convs, (str, bool))
100+
if isinstance(add_extra_convs, str):
101+
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
102+
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
103+
elif add_extra_convs: # True
104+
if extra_convs_on_inputs:
105+
# For compatibility with previous release
106+
# TODO: deprecate `extra_convs_on_inputs`
107+
self.add_extra_convs = 'on_input'
108+
else:
109+
self.add_extra_convs = 'on_output'
110+
111+
self.lateral_convs = nn.ModuleList()
112+
self.fpn_convs = nn.ModuleList()
113+
114+
for i in range(self.start_level, self.backbone_end_level):
115+
l_conv = ConvModule(
116+
in_channels[i],
117+
out_channels,
118+
1,
119+
conv_cfg=conv_cfg,
120+
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
121+
act_cfg=act_cfg,
122+
inplace=False)
123+
fpn_conv = ConvModule(
124+
out_channels,
125+
out_channels,
126+
3,
127+
padding=1,
128+
conv_cfg=conv_cfg,
129+
norm_cfg=norm_cfg,
130+
act_cfg=act_cfg,
131+
inplace=False)
132+
133+
self.lateral_convs.append(l_conv)
134+
self.fpn_convs.append(fpn_conv)
135+
136+
# add extra conv layers (e.g., RetinaNet)
137+
extra_levels = num_outs - self.backbone_end_level + self.start_level
138+
if self.add_extra_convs and extra_levels >= 1:
139+
for i in range(extra_levels):
140+
if i == 0 and self.add_extra_convs == 'on_input':
141+
in_channels = self.in_channels[self.backbone_end_level - 1]
142+
else:
143+
in_channels = out_channels
144+
extra_fpn_conv = ConvModule(
145+
in_channels,
146+
out_channels,
147+
3,
148+
stride=2,
149+
padding=1,
150+
conv_cfg=conv_cfg,
151+
norm_cfg=norm_cfg,
152+
act_cfg=act_cfg,
153+
inplace=False)
154+
self.fpn_convs.append(extra_fpn_conv)
155+
156+
# default init_weights for conv(msra) and norm in ConvModule
157+
def init_weights(self):
158+
for m in self.modules():
159+
if isinstance(m, nn.Conv2d):
160+
xavier_init(m, distribution='uniform')
161+
162+
def forward(self, inputs):
163+
assert len(inputs) == len(self.in_channels)
164+
165+
# build laterals
166+
laterals = [
167+
lateral_conv(inputs[i + self.start_level])
168+
for i, lateral_conv in enumerate(self.lateral_convs)
169+
]
170+
171+
# build top-down path
172+
used_backbone_levels = len(laterals)
173+
for i in range(used_backbone_levels - 1, 0, -1):
174+
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
175+
# it cannot co-exist with `size` in `F.interpolate`.
176+
if 'scale_factor' in self.upsample_cfg:
177+
laterals[i - 1] += F.interpolate(laterals[i],
178+
**self.upsample_cfg)
179+
else:
180+
prev_shape = laterals[i - 1].shape[2:]
181+
laterals[i - 1] += F.interpolate(
182+
laterals[i], size=prev_shape, **self.upsample_cfg)
183+
184+
# build outputs
185+
# part 1: from original levels
186+
outs = [
187+
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
188+
]
189+
# part 2: add extra levels
190+
if self.num_outs > len(outs):
191+
# use max pool to get more levels on top of outputs
192+
# (e.g., Faster R-CNN, Mask R-CNN)
193+
if not self.add_extra_convs:
194+
for i in range(self.num_outs - used_backbone_levels):
195+
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
196+
# add conv layers on top of original feature maps (RetinaNet)
197+
else:
198+
if self.add_extra_convs == 'on_input':
199+
extra_source = inputs[self.backbone_end_level - 1]
200+
elif self.add_extra_convs == 'on_lateral':
201+
extra_source = laterals[-1]
202+
elif self.add_extra_convs == 'on_output':
203+
extra_source = outs[-1]
204+
else:
205+
raise NotImplementedError
206+
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
207+
for i in range(used_backbone_levels + 1, self.num_outs):
208+
if self.relu_before_extra_convs:
209+
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
210+
else:
211+
outs.append(self.fpn_convs[i](outs[-1]))
212+
return tuple(outs)

0 commit comments

Comments
 (0)