Skip to content

Commit

Permalink
Add Semantic FPN (#94)
Browse files Browse the repository at this point in the history
* Add Semantic FPN

* remove HRFPN
  • Loading branch information
xvjiarui committed Sep 3, 2020
1 parent 597b8a6 commit b8f42c7
Show file tree
Hide file tree
Showing 14 changed files with 388 additions and 37 deletions.
36 changes: 36 additions & 0 deletions configs/_base_/models/fpn_r50.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
decode_head=dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')
30 changes: 30 additions & 0 deletions configs/sem_fpn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Panoptic Feature Pyramid Networks

## Introduction
```
@article{Kirillov_2019,
title={Panoptic Feature Pyramid Networks},
ISBN={9781728132938},
url={http://dx.doi.org/10.1109/CVPR.2019.00656},
DOI={10.1109/cvpr.2019.00656},
journal={2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
publisher={IEEE},
author={Kirillov, Alexander and Girshick, Ross and He, Kaiming and Dollar, Piotr},
year={2019},
month={Jun}
}
```

## Results and models

### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| FPN | R-50 | 512x1024 | 80000 | 2.8 | 13.54 | 74.52 | 76.08 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes_20200717_021437-94018a0d.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x1024_80k_cityscapes/fpn_r50_512x1024_80k_cityscapes-20200717_021437.log.json) |
| FPN | R-101 | 512x1024 | 80000 | 3.9 | 10.29 | 75.80 | 77.40 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes_20200717_012416-c5800d4c.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x1024_80k_cityscapes/fpn_r101_512x1024_80k_cityscapes-20200717_012416.log.json) |

### ADE20K
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| FPN | R-50 | 512x512 | 160000 | 4.9 | 55.77 | 37.49 | 39.09 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k_20200718_131734-5b5a6ab9.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r50_512x512_160k_ade20k/fpn_r50_512x512_160k_ade20k-20200718_131734.log.json) |
| FPN | R-101 | 512x512 | 160000 | 5.9 | 40.58 | 39.35 | 40.72 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k_20200718_131734-306b5004.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/sem_fpn/fpn_r101_512x512_160k_ade20k/fpn_r101_512x512_160k_ade20k-20200718_131734.log.json) |
2 changes: 2 additions & 0 deletions configs/sem_fpn/fpn_r101_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './fpn_r50_512x1024_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/sem_fpn/fpn_r101_512x512_160k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './fpn_r50_512x512_160k_ade20k.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
4 changes: 4 additions & 0 deletions configs/sem_fpn/fpn_r50_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/fpn_r50.py', '../_base_/datasets/cityscapes.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
5 changes: 5 additions & 0 deletions configs/sem_fpn/fpn_r50_512x512_160k_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = [
'../_base_/models/fpn_r50.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
model = dict(decode_head=dict(num_classes=150))
1 change: 1 addition & 0 deletions mmseg/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
build_head, build_loss, build_segmentor)
from .decode_heads import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .segmentors import * # noqa: F401,F403

__all__ = [
Expand Down
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .da_head import DAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
from .nl_head import NLHead
from .ocr_head import OCRHead
Expand All @@ -16,5 +17,5 @@
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead'
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
]
68 changes: 68 additions & 0 deletions mmseg/models/decode_heads/fpn_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.ops import resize
from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class FPNHead(BaseDecodeHead):
"""Panoptic Feature Pyramid Networks.
This head is the implementation of `Semantic FPN
<https://arxiv.org/abs/1901.02446>`_.
Args:
feature_strides (tuple[int]): The strides for input feature maps.
stack_lateral. All strides suppose to be power of 2. The first
one is of largest resolution.
"""

def __init__(self, feature_strides, **kwargs):
super(FPNHead, self).__init__(
input_transform='multiple_select', **kwargs)
assert len(feature_strides) == len(self.in_channels)
assert min(feature_strides) == feature_strides[0]
self.feature_strides = feature_strides

self.scale_heads = nn.ModuleList()
for i in range(len(feature_strides)):
head_length = max(
1,
int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
scale_head = []
for k in range(head_length):
scale_head.append(
ConvModule(
self.in_channels[i] if k == 0 else self.channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
if feature_strides[i] != feature_strides[0]:
scale_head.append(
nn.Upsample(
scale_factor=2,
mode='bilinear',
align_corners=self.align_corners))
self.scale_heads.append(nn.Sequential(*scale_head))

def forward(self, inputs):

x = self._transform_inputs(inputs)

output = self.scale_heads[0](x[0])
for i in range(1, len(self.feature_strides)):
# non inplace
output = output + resize(
self.scale_heads[i](x[i]),
size=output.shape[2:],
mode='bilinear',
align_corners=self.align_corners)

output = self.cls_seg(output)
return output
3 changes: 3 additions & 0 deletions mmseg/models/necks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .fpn import FPN

__all__ = ['FPN']
212 changes: 212 additions & 0 deletions mmseg/models/necks/fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init

from ..builder import NECKS


@NECKS.register_module()
class FPN(nn.Module):
"""Feature Pyramid Network.
This is an implementation of - Feature Pyramid Networks for Object
Detection (https://arxiv.org/abs/1612.03144)
Args:
in_channels (List[int]): Number of input channels per scale.
out_channels (int): Number of output channels (used at each scale)
num_outs (int): Number of output scales.
start_level (int): Index of the start input backbone level used to
build the feature pyramid. Default: 0.
end_level (int): Index of the end input backbone level (exclusive) to
build the feature pyramid. Default: -1, which means the last level.
add_extra_convs (bool | str): If bool, it decides whether to add conv
layers on top of the original feature maps. Default to False.
If True, its actual mode is specified by `extra_convs_on_inputs`.
If str, it specifies the source feature map of the extra convs.
Only the following options are allowed
- 'on_input': Last feat map of neck inputs (i.e. backbone feature).
- 'on_lateral': Last feature map after lateral convs.
- 'on_output': The last output feature map after fpn convs.
extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
on the original feature from the backbone. If True,
it is equivalent to `add_extra_convs='on_input'`. If False, it is
equivalent to set `add_extra_convs='on_output'`. Default to True.
relu_before_extra_convs (bool): Whether to apply relu before the extra
conv. Default: False.
no_norm_on_lateral (bool): Whether to apply norm on lateral.
Default: False.
conv_cfg (dict): Config dict for convolution layer. Default: None.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (str): Config dict for activation layer in ConvModule.
Default: None.
upsample_cfg (dict): Config dict for interpolate layer.
Default: `dict(mode='nearest')`
Example:
>>> import torch
>>> in_channels = [2, 3, 5, 7]
>>> scales = [340, 170, 84, 43]
>>> inputs = [torch.rand(1, c, s, s)
... for c, s in zip(in_channels, scales)]
>>> self = FPN(in_channels, 11, len(in_channels)).eval()
>>> outputs = self.forward(inputs)
>>> for i in range(len(outputs)):
... print(f'outputs[{i}].shape = {outputs[i].shape}')
outputs[0].shape = torch.Size([1, 11, 340, 340])
outputs[1].shape = torch.Size([1, 11, 170, 170])
outputs[2].shape = torch.Size([1, 11, 84, 84])
outputs[3].shape = torch.Size([1, 11, 43, 43])
"""

def __init__(self,
in_channels,
out_channels,
num_outs,
start_level=0,
end_level=-1,
add_extra_convs=False,
extra_convs_on_inputs=False,
relu_before_extra_convs=False,
no_norm_on_lateral=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
upsample_cfg=dict(mode='nearest')):
super(FPN, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
self.fp16_enabled = False
self.upsample_cfg = upsample_cfg.copy()

if end_level == -1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level < inputs, no extra level is allowed
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
assert isinstance(add_extra_convs, (str, bool))
if isinstance(add_extra_convs, str):
# Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
elif add_extra_convs: # True
if extra_convs_on_inputs:
# For compatibility with previous release
# TODO: deprecate `extra_convs_on_inputs`
self.add_extra_convs = 'on_input'
else:
self.add_extra_convs = 'on_output'

self.lateral_convs = nn.ModuleList()
self.fpn_convs = nn.ModuleList()

for i in range(self.start_level, self.backbone_end_level):
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=act_cfg,
inplace=False)
fpn_conv = ConvModule(
out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)

self.lateral_convs.append(l_conv)
self.fpn_convs.append(fpn_conv)

# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
if self.add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
if i == 0 and self.add_extra_convs == 'on_input':
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
extra_fpn_conv = ConvModule(
in_channels,
out_channels,
3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False)
self.fpn_convs.append(extra_fpn_conv)

# default init_weights for conv(msra) and norm in ConvModule
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')

def forward(self, inputs):
assert len(inputs) == len(self.in_channels)

# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]

# build top-down path
used_backbone_levels = len(laterals)
for i in range(used_backbone_levels - 1, 0, -1):
# In some cases, fixing `scale factor` (e.g. 2) is preferred, but
# it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
laterals[i - 1] += F.interpolate(laterals[i],
**self.upsample_cfg)
else:
prev_shape = laterals[i - 1].shape[2:]
laterals[i - 1] += F.interpolate(
laterals[i], size=prev_shape, **self.upsample_cfg)

# build outputs
# part 1: from original levels
outs = [
self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
]
# part 2: add extra levels
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 1, stride=2))
# add conv layers on top of original feature maps (RetinaNet)
else:
if self.add_extra_convs == 'on_input':
extra_source = inputs[self.backbone_end_level - 1]
elif self.add_extra_convs == 'on_lateral':
extra_source = laterals[-1]
elif self.add_extra_convs == 'on_output':
extra_source = outs[-1]
else:
raise NotImplementedError
outs.append(self.fpn_convs[used_backbone_levels](extra_source))
for i in range(used_backbone_levels + 1, self.num_outs):
if self.relu_before_extra_convs:
outs.append(self.fpn_convs[i](F.relu(outs[-1])))
else:
outs.append(self.fpn_convs[i](outs[-1]))
return tuple(outs)
Loading

0 comments on commit b8f42c7

Please sign in to comment.