Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support BiSeNetV1 #851

Merged
merged 14 commits into from
Sep 28, 2021
35 changes: 35 additions & 0 deletions configs/_base_/datasets/cityscapes_1024x1024.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
_base_ = './cityscapes.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (1024, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 1024),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
68 changes: 68 additions & 0 deletions configs/_base_/models/bisenetv1_r18-d32.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
type='BiSeNetV1',
in_channels=3,
context_channels=(128, 256, 512),
spatial_channels=(64, 64, 64, 128),
out_indices=(0, 1, 2),
out_channels=256,
backbone_cfg=dict(
type='ResNet',
in_channels=3,
depth=18,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
norm_cfg=norm_cfg,
align_corners=False,
init_cfg=None),
decode_head=dict(
type='FCNHead',
in_channels=256,
in_index=0,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=128,
channels=64,
num_convs=1,
num_classes=19,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
],
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
29 changes: 29 additions & 0 deletions configs/bisenetv1/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation

## Introduction

<!-- [ALGORITHM] -->

```latex
@inproceedings{yu2018bisenet,
title={Bisenet: Bilateral segmentation network for real-time semantic segmentation},
author={Yu, Changqian and Wang, Jingbo and Peng, Chao and Gao, Changxin and Yu, Gang and Sang, Nong},
booktitle={Proceedings of the European conference on computer vision (ECCV)},
pages={325--341},
year={2018}
}
```

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| BiSeNetV1 (4x4) | R-18-D32 | 1024x1024 | 160000 | 3.3 | 31.77 | 74.37 | 76.91 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv1/bisenetv1_r18-d32_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_lr0.12_8x4_160k_cityscapes/fast_scnn_lr0.12_8x4_160k_cityscapes_20210630_164853-0cec9937.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_lr0.12_8x4_160k_cityscapes/fast_scnn_lr0.12_8x4_160k_cityscapes_20210630_164853.log.json) |
| BiSeNetV1 (4x8) | R-18-D32 | 1024x1024 | 160000 | 3.3 | - | 75.16 | 77.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv1/bisenetv1_r18-d32_4x8_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_lr0.12_8x4_160k_cityscapes/fast_scnn_lr0.12_8x4_160k_cityscapes_20210630_164853-0cec9937.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_lr0.12_8x4_160k_cityscapes/fast_scnn_lr0.12_8x4_160k_cityscapes_20210630_164853.log.json) |

Note:

- `4x4`: Using 4 GPUs with 4 samples per GPU in training.
- `4x8`: Using 4 GPUs with 8 samples per GPU in training.
43 changes: 43 additions & 0 deletions configs/bisenetv1/bisenetv1.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
Collections:
- Metadata:
Training Data:
- Cityscapes
Name: bisenetv1
Models:
- Config: configs/bisenetv1/bisenetv1_r18-d32_4x4_1024x1024_160k_cityscapes.py
In Collection: bisenetv1
Metadata:
backbone: R-18-D32
crop size: (1024,1024)
inference time (ms/im):
- backend: PyTorch
batch size: 1
hardware: V100
mode: FP32
resolution: (1024,1024)
value: 31.48
lr schd: 160000
memory (GB): 3.3
Name: bisenetv1_r18-d32_4x4_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 74.37
mIoU(ms+flip): 76.91
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_lr0.12_8x4_160k_cityscapes/fast_scnn_lr0.12_8x4_160k_cityscapes_20210630_164853-0cec9937.pth
- Config: configs/bisenetv1/bisenetv1_r18-d32_4x8_1024x1024_160k_cityscapes.py
In Collection: bisenetv1
Metadata:
backbone: R-18-D32
crop size: (1024,1024)
lr schd: 160000
memory (GB): 3.3
Name: bisenetv1_r18-d32_4x8_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 75.16
mIoU(ms+flip): 77.24
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/fast_scnn/fast_scnn_lr0.12_8x4_160k_cityscapes/fast_scnn_lr0.12_8x4_160k_cityscapes_20210630_164853-0cec9937.pth
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = [
'../_base_/models/bisenetv1_r18-d32.py',
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.025)
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_base_ = [
'../_base_/models/bisenetv1_r18-d32.py',
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.025)
data = dict(
samples_per_gpu=8,
workers_per_gpu=8,
)
3 changes: 2 additions & 1 deletion mmseg/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bisenetv1 import BiSeNetV1
from .cgnet import CGNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
Expand All @@ -15,5 +16,5 @@
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer'
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV1'
]
Loading