Skip to content

Commit

Permalink
Merge 9c227d7 into 29c82ea
Browse files Browse the repository at this point in the history
  • Loading branch information
MengzhangLI authored Sep 26, 2021
2 parents 29c82ea + 9c227d7 commit b2ee8df
Show file tree
Hide file tree
Showing 15 changed files with 952 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ Supported methods:
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
- [x] [SegFormer (ArXiv'2021)](configs/segformer)

Supported datasets:
Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
- [x] [SegFormer (ArXiv'2021)](configs/segformer)

已支持的数据集:
Expand Down
35 changes: 35 additions & 0 deletions configs/_base_/datasets/cityscapes_1024x1024.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Cityscapes pipelines using 1024x1024 training crops. Anything not set here
# is taken from the base Cityscapes dataset config.
_base_ = './cityscapes.py'

# Normalisation arguments forwarded to the `Normalize` transform in both
# pipelines below.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    to_rgb=True,
)
crop_size = (1024, 1024)

# Training: load, randomly rescale/crop/flip, apply colour jitter, normalise,
# pad up to the crop size and bundle the image together with its label map.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='Resize',
        img_scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
    ),
    dict(
        type='RandomCrop',
        crop_size=crop_size,
        cat_max_ratio=0.75,
    ),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(
        type='Pad',
        size=crop_size,
        pad_val=0,
        seg_pad_val=255,
    ),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

# Testing: a single scale with flipping disabled.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 1024),
        # Left disabled; presumably meant to be uncommented for multi-scale
        # testing -- confirm against the MultiScaleFlipAug documentation.
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ],
    ),
]

# Validation and test splits share the test pipeline.
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline),
)
80 changes: 80 additions & 0 deletions configs/_base_/models/bisenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='BiSeNetV2',
        detail_channels=(64, 64, 128),
        semantic_channels=(16, 32, 64, 128),
        semantic_expansion_ratio=6,
        bga_channels=128,
        out_indices=(0, 1, 2, 3, 4),
        init_cfg=None,
        align_corners=False,
    ),
    # Main head: consumes backbone output 0.
    decode_head=dict(
        type='FCNHead',
        in_channels=128,
        in_index=0,
        channels=1024,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    ),
    # Four auxiliary FCN heads, one per backbone output 1-4. They differ only
    # in input width, hidden width and input index, so they are produced from
    # the table at the bottom; each head receives its own `loss_decode` dict.
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=aux_in_channels,
            channels=aux_channels,
            num_convs=2,
            num_classes=19,
            in_index=aux_in_index,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
        )
        # (in_channels, channels, in_index)
        for aux_in_channels, aux_channels, aux_in_index in (
            (16, 16, 1),
            (32, 64, 2),
            (64, 256, 3),
            (128, 1024, 4),
        )
    ],
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
)
33 changes: 33 additions & 0 deletions configs/bisenetv2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Bisenet v2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{yu2021bisenet,
title={Bisenet v2: Bilateral network with guided aggregation for real-time semantic segmentation},
author={Yu, Changqian and Gao, Changxin and Wang, Jingbo and Yu, Gang and Shen, Chunhua and Sang, Nong},
journal={International Journal of Computer Vision},
pages={1--18},
year={2021},
publisher={Springer}
}
```

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | --------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| BiSeNetV2 | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | 31.77 | 73.21 | 75.74 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551.log.json) |
| BiSeNetV2 (OHEM) | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | - | 73.57 | 75.80 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947.log.json) |
| BiSeNetV2 (4x8) | BiSeNetV2 | 1024x1024 | 160000 | 15.05 | - | 75.76 | 77.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032.log.json) |
| BiSeNetV2 (FP16) | BiSeNetV2 | 1024x1024 | 160000 | 5.77 | 36.65 | 73.07 | 75.13 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942.log.json) |

Note:

- `OHEM` means Online Hard Example Mining (OHEM) is adopted in training.
- `FP16` means Mixed Precision (FP16) is adopted in training.
- `4x8` means 4 GPUs with 8 samples per GPU in training.
80 changes: 80 additions & 0 deletions configs/bisenetv2/bisenetv2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
Collections:
- Metadata:
Training Data:
- Cityscapes
Name: bisenetv2
Models:
- Config: configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
inference time (ms/im):
- backend: PyTorch
batch size: 1
hardware: V100
mode: FP32
resolution: (1024,1024)
value: 31.48
lr schd: 160000
memory (GB): 7.64
Name: bisenetv2_fcn_4x4_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 73.21
mIoU(ms+flip): 75.74
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth
- Config: configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
lr schd: 160000
memory (GB): 7.64
Name: bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 73.57
mIoU(ms+flip): 75.8
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth
- Config: configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
lr schd: 160000
memory (GB): 15.05
Name: bisenetv2_fcn_4x8_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 75.76
mIoU(ms+flip): 77.79
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth
- Config: configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
inference time (ms/im):
- backend: PyTorch
batch size: 1
hardware: V100
mode: FP16
resolution: (1024,1024)
value: 27.29
lr schd: 160000
memory (GB): 5.77
Name: bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 73.07
mIoU(ms+flip): 75.13
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth
11 changes: 11 additions & 0 deletions configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# BiSeNetV2 with FCN heads on Cityscapes (1024x1024 crops, 160k schedule),
# using 4 samples per GPU.
_base_ = [
    '../_base_/models/bisenetv2.py',
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py',
]
# Overrides merged on top of the inherited schedule: linear warm-up over the
# first 1000 iterations and a learning rate of 0.05.
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.05)
# Per-GPU batch size and dataloader worker count.
data = dict(samples_per_gpu=4, workers_per_gpu=4)
11 changes: 11 additions & 0 deletions configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# BiSeNetV2 with FCN heads on Cityscapes (1024x1024 crops, 160k schedule),
# using 8 samples per GPU.
_base_ = [
    '../_base_/models/bisenetv2.py',
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py',
]
# Overrides merged on top of the inherited schedule: linear warm-up over the
# first 1000 iterations and a learning rate of 0.05.
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.05)
# Per-GPU batch size and dataloader worker count.
data = dict(samples_per_gpu=8, workers_per_gpu=8)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Mixed-precision variant: identical to the 4x4 config apart from the two
# fp16-related keys below.
_base_ = './bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py'

# fp16 settings: switch the optimizer hook to the fp16 one with a fixed loss
# scale of 512.
optimizer_config = dict(
    type='Fp16OptimizerHook',
    loss_scale=512.0,
)
# fp16 placeholder (kept as an empty dict, as in the original config).
fp16 = dict()
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# BiSeNetV2 with FCN heads on Cityscapes (1024x1024 crops, 160k schedule),
# 4 samples per GPU, trained with Online Hard Example Mining (OHEM).
_base_ = [
    '../_base_/models/bisenetv2.py',
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
# OHEM pixel sampler settings. The top-level key is kept so that anything
# reading `cfg.sampler` continues to work.
sampler = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)
# Attach the sampler to the decode head. Neither this file nor the base model
# config referenced `sampler` before, so without this override the run was
# configured exactly like the plain 4x4 config.
# NOTE(review): the auxiliary heads are left as inherited; they are a list in
# the base config, which presumably has to be overridden as a whole -- confirm
# whether OHEM should be applied to them as well.
model = dict(decode_head=dict(sampler=sampler))
# Linear warm-up over the first 1000 iterations, learning rate 0.05.
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.05)
# Per-GPU batch size and dataloader worker count.
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
)
3 changes: 2 additions & 1 deletion mmseg/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
Expand All @@ -15,5 +16,5 @@
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer'
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV2'
]
Loading

0 comments on commit b2ee8df

Please sign in to comment.