Skip to content

Commit

Permalink
[Feature] Support K-Net (open-mmlab#1289)
Browse files Browse the repository at this point in the history
* knet first commit

* fix import error in knet

* remove kernel update head from decoder head

* [Feature] Add kenerl updation for some decoder heads.

* [Feature] Add kenerl updation for some decoder heads.

* directly use forward_feature && modify other 3 decoder heads

* remover kernel_update attr

* delete unnecessary variables in forward function

* delete kernel update function

* delete kernel update function

* delete kernel_generate_head

* add unit test & comments in knet.py

* add copyright to fix lint error

* modify config names of knet

* rename swin-l 640

* upload models&logs and refactor knet_head.py

* modify docstrings and add some ut

* add url, modify docstring and add loss ut

* modify docstrings
  • Loading branch information
MengzhangLI authored and mob5566 committed Apr 13, 2022
1 parent 06647b7 commit ec4bec1
Show file tree
Hide file tree
Showing 15 changed files with 1,373 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ Supported methods:
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)

Supported datasets:

Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [DPT (ArXiv'2021)](configs/dpt)
- [x] [Segmenter (ICCV'2021)](configs/segmenter)
- [x] [SegFormer (NeurIPS'2021)](configs/segformer)
- [x] [K-Net (NeurIPS'2021)](configs/knet)

已支持的数据集:

Expand Down
49 changes: 49 additions & 0 deletions configs/knet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# K-Net

[K-Net: Towards Unified Image Segmentation](https://arxiv.org/abs/2106.14855)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/ZwwWayne/K-Net/">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/decode_heads/knet_head.py#L392">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

Semantic, instance, and panoptic segmentations have been addressed using different and specialized frameworks despite their underlying connections. This paper presents a unified, simple, and effective framework for these essentially similar tasks. The framework, named K-Net, segments both instances and semantic categories consistently by a group of learnable kernels, where each kernel is responsible for generating a mask for either a potential instance or a stuff class. To remedy the difficulties of distinguishing various instances, we propose a kernel update strategy that enables each kernel dynamic and conditional on its meaningful group in the input image. K-Net can be trained in an end-to-end manner with bipartite matching, and its training and inference are naturally NMS-free and box-free. Without bells and whistles, K-Net surpasses all previous published state-of-the-art single-model results of panoptic segmentation on MS COCO test-dev split and semantic segmentation on ADE20K val split with 55.2% PQ and 54.3% mIoU, respectively. Its instance segmentation performance is also on par with Cascade Mask R-CNN on MS COCO with 60%-90% faster inference speeds. Code and models will be released at [this https URL](https://github.com/ZwwWayne/K-Net/).

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/157008300-9f40905c-b8e8-4a2a-9593-c1177fa35b2c.png" width="90%"/>
</div>

```bibtex
@inproceedings{zhang2021knet,
title={{K-Net: Towards} Unified Image Segmentation},
author={Wenwei Zhang and Jiangmiao Pang and Kai Chen and Chen Change Loy},
year={2021},
booktitle={NeurIPS},
}
```

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| --------------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------- | ----- |
| KNet + FCN | R-50-D8 | 512x512 | 80000 | 7.01 | 19.24 | 43.60 | 45.12 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751.log.json) |
| KNet + PSPNet | R-50-D8 | 512x512 | 80000 | 6.98 | 20.04 | 44.18 | 45.58 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634.log.json) |
| KNet + DeepLabV3| R-50-D8 | 512x512 | 80000 | 7.42 | 12.10 | 45.06 | 46.11 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642.log.json) |
| KNet + UperNet | R-50-D8 | 512x512 | 80000 | 7.34 | 17.11 | 43.45 | 44.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657.log.json) |
| KNet + UperNet | Swin-T | 512x512 | 80000 | 7.57 | 15.56 | 45.84 | 46.27 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059.log.json) |
| KNet + UperNet | Swin-L | 512x512 | 80000 | 13.5 | 8.29 | 52.05 | 53.24 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559.log.json) |
| KNet + UperNet | Swin-L | 640x640 | 80000 | 13.54 | 8.29 | 52.21 | 53.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747.log.json) |

Note:

- All experiments of K-Net are implemented with 8 V100 (32G) GPUs with 2 samplers per GPU.
169 changes: 169 additions & 0 deletions configs/knet/knet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
Collections:
- Name: KNet
Metadata:
Training Data:
- ADE20K
Paper:
URL: https://arxiv.org/abs/2106.14855
Title: 'K-Net: Towards Unified Image Segmentation'
README: configs/knet/README.md
Code:
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/decode_heads/knet_head.py#L392
Version: v0.23.0
Converted From:
Code: https://github.com/ZwwWayne/K-Net/
Models:
- Name: knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k
In Collection: KNet
Metadata:
backbone: R-50-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 51.98
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 7.01
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 43.6
mIoU(ms+flip): 45.12
Config: configs/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_fcn_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_043751-abcab920.pth
- Name: knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k
In Collection: KNet
Metadata:
backbone: R-50-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 49.9
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 6.98
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 44.18
mIoU(ms+flip): 45.58
Config: configs/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_pspnet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_054634-d2c72240.pth
- Name: knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k
In Collection: KNet
Metadata:
backbone: R-50-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 82.64
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 7.42
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 45.06
mIoU(ms+flip): 46.11
Config: configs/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_deeplabv3_r50-d8_8x2_512x512_adamw_80k_ade20k_20220228_041642-00c8fbeb.pth
- Name: knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k
In Collection: KNet
Metadata:
backbone: R-50-D8
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 58.45
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 7.34
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 43.45
mIoU(ms+flip): 44.07
Config: configs/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_r50-d8_8x2_512x512_adamw_80k_ade20k_20220304_125657-215753b0.pth
- Name: knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k
In Collection: KNet
Metadata:
backbone: Swin-T
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 64.27
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 7.57
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 45.84
mIoU(ms+flip): 46.27
Config: configs/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-t_8x2_512x512_adamw_80k_ade20k_20220303_133059-7545e1dc.pth
- Name: knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k
In Collection: KNet
Metadata:
backbone: Swin-L
crop size: (512,512)
lr schd: 80000
inference time (ms/im):
- value: 120.63
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 13.5
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 52.05
mIoU(ms+flip): 53.24
Config: configs/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_512x512_adamw_80k_ade20k_20220303_154559-d8da9a90.pth
- Name: knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k
In Collection: KNet
Metadata:
backbone: Swin-L
crop size: (640,640)
lr schd: 80000
inference time (ms/im):
- value: 120.63
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (640,640)
Training Memory (GB): 13.54
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 52.21
mIoU(ms+flip): 53.34
Config: configs/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/knet/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k/knet_s3_upernet_swin-l_8x2_640x640_adamw_80k_ade20k_20220301_220747-8787fc71.pth
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
_base_ = [
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
num_stages = 3
conv_kernel_size = 1
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='IterativeDecodeHead',
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=2048,
in_channels=512,
out_channels=512,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(
conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_updator_cfg=dict(
type='KernelUpdator',
in_channels=256,
feat_channels=256,
out_channels=256,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
],
kernel_generate_head=dict(
type='ASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
# learning policy
lr_config = dict(
_delete_=True,
policy='step',
warmup='linear',
warmup_iters=1000,
warmup_ratio=0.001,
step=[60000, 72000],
by_epoch=False)
# In K-Net implementation we use batch size 2 per GPU as default
data = dict(samples_per_gpu=2, workers_per_gpu=2)
Loading

0 comments on commit ec4bec1

Please sign in to comment.